# gpt_neox.py

# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^gpt_neox.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings", False):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("embed_out.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
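

# Illustrative sketch (not part of the original file): the rearrange used above converts
# GPT-NeoX's head-interleaved QKV layout, where each head contributes its q, k, and v rows
# back to back, into the layout expected here, where all q heads come first, then all k
# heads, then all v heads. `_demo_wqkv_layout` is a hypothetical helper added only to
# demonstrate that equivalence on random data.
def _demo_wqkv_layout(nheads=2, headdim=3, hidden_dim=4):
    Wqkv_neox = torch.randn(nheads * 3 * headdim, hidden_dim)
    converted = rearrange(
        Wqkv_neox, "(nheads three headdim) d -> (three nheads headdim) d", three=3, headdim=headdim
    )
    # Manual equivalent: view the rows as (nheads, 3, headdim), then stack all q slices,
    # all k slices, and all v slices along the row dimension.
    per_head = Wqkv_neox.reshape(nheads, 3, headdim, hidden_dim)
    manual = torch.cat(
        [per_head[:, i].reshape(nheads * headdim, hidden_dim) for i in range(3)], dim=0
    )
    assert torch.equal(converted, manual)
    return converted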


def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
    assert gpt_neox_config.rotary_emb_base == 10000
    return GPT2Config(
        vocab_size=gpt_neox_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gpt_neox_config.hidden_size,
        n_layer=gpt_neox_config.num_hidden_layers,
        n_head=gpt_neox_config.num_attention_heads,
        n_inner=gpt_neox_config.intermediate_size,
        activation_function=gpt_neox_config.hidden_act,
        resid_pdrop=0.0,  # No dropout
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
        initializer_range=gpt_neox_config.initializer_range,
        bos_token_id=gpt_neox_config.bos_token_id,
        eos_token_id=gpt_neox_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=gpt_neox_config.use_parallel_residual,
        parallel_block_tied_norm=False,
        rotary_emb_fraction=gpt_neox_config.rotary_pct,
        tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
    )
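

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): build a tiny, randomly
    # initialized HF GPT-NeoX model and remap its weights. This assumes a transformers
    # version in which GPTNeoXAttention still keeps the `bias` and `masked_bias` buffers
    # in the state dict, since the pops in remap_state_dict_hf_gpt_neox expect them
    # to be present.
    from transformers import GPTNeoXForCausalLM

    neox_config = GPTNeoXConfig(
        vocab_size=1024,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
    )
    hf_model = GPTNeoXForCausalLM(neox_config)
    gpt2_config = gpt_neox_config_to_gpt2_config(neox_config)
    remapped = remap_state_dict_hf_gpt_neox(hf_model.state_dict(), gpt2_config)
    # The remapped dict uses transformer.layers.{i}.mixer / .mlp / .norm1 / .norm2 keys,
    # i.e. the naming that the flash-attn GPT model expects.
    print(list(remapped.keys())[:8])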