# gptj.py — GPT-J checkpoint/config conversion utilities
  1. # Copyright (c) 2023, Tri Dao.
  2. import math
  3. import re
  4. from collections import OrderedDict
  5. import torch
  6. import torch.nn.functional as F
  7. from transformers import GPT2Config, GPTJConfig
  8. def remap_state_dict_hf_gptj(state_dict, config):
  9. def key_mapping_layers(key):
  10. return re.sub(r"^transformer.h.", "transformer.layers.", key)
  11. state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
  12. # Word embedding
  13. def key_mapping_emb(key):
  14. return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)
  15. state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
  16. word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
  17. # It's possible that vocab_size is padded to be a multiple of 8, for example.
  18. pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
  19. vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
  20. state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
  21. word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
  22. )
  23. if getattr(config, "tie_word_embeddings"):
  24. state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
  25. else:
  26. output_embeddings = state_dict.pop("lm_head.weight")
  27. # It's possible that vocab_size is padded to be a multiple of 8, for example.
  28. state_dict["lm_head.weight"] = F.pad(
  29. output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
  30. )
  31. output_embeddings_bias = state_dict.pop("lm_head.bias")
  32. state_dict["lm_head.bias"] = F.pad(
  33. output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
  34. )
  35. # LayerNorm
  36. def key_mapping_ln(key):
  37. return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)
  38. state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
  39. # MLP
  40. def key_mapping_mlp(key):
  41. key = re.sub(
  42. r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
  43. )
  44. key = re.sub(
  45. r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
  46. )
  47. return key
  48. state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
  49. # Attention
  50. for l in range(config.n_layer):
  51. Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
  52. Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
  53. Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
  54. state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
  55. # We don't store these biases
  56. state_dict.pop(f"transformer.layers.{l}.attn.bias")
  57. state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")
  58. def key_mapping_attn(key):
  59. return re.sub(
  60. r"^transformer.layers.(\d+).attn.out_proj.",
  61. r"transformer.layers.\1.mixer.out_proj.",
  62. key,
  63. )
  64. state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
  65. return state_dict
  66. def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
  67. headdim = gptj_config.n_embd // gptj_config.n_head
  68. return GPT2Config(
  69. vocab_size=gptj_config.vocab_size,
  70. n_positions=0, # No absolute position embedding
  71. n_embd=gptj_config.n_embd,
  72. n_layer=gptj_config.n_layer,
  73. n_head=gptj_config.n_head,
  74. n_inner=gptj_config.n_inner,
  75. activation_function=gptj_config.activation_function,
  76. resid_pdrop=gptj_config.resid_pdrop,
  77. embd_pdrop=gptj_config.embd_pdrop,
  78. attn_pdrop=gptj_config.attn_pdrop,
  79. layer_norm_epsilon=gptj_config.layer_norm_epsilon,
  80. initializer_range=gptj_config.initializer_range,
  81. bos_token_id=gptj_config.bos_token_id,
  82. eos_token_id=gptj_config.eos_token_id,
  83. # These are new arguments not in the original GPT2Config
  84. prenorm=True,
  85. parallel_block=True,
  86. parallel_block_tied_norm=True,
  87. rotary_emb_fraction=gptj_config.rotary_dim / headdim,
  88. rotary_emb_interleaved=True,
  89. tie_word_embeddings=False,
  90. qkv_proj_bias=False,
  91. out_proj_bias=False,
  92. lm_head_bias=True,
  93. )