opt.py

# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig


def remap_state_dict_hf_opt(state_dict, config):
    def key_mapping_model(key):
        key = re.sub(r"^model.decoder.", "transformer.", key)
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        key = re.sub(r"^decoder.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_emb(key):
        key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
        # The OPT-350m model has project_in and project_out
        key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
        key = re.sub(r"^transformer.project_out.", "project_out.", key)
        key = re.sub(
            r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    # OPT uses the first 2 indices of pos_emb for padding tokens
    pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
    state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
        key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
        bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
        bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
        bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).self_attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    word_embed_proj_dim = (
        None
        if opt_config.word_embed_proj_dim == opt_config.hidden_size
        else opt_config.word_embed_proj_dim
    )
    return GPT2Config(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
        n_embd=opt_config.hidden_size,
        n_layer=opt_config.num_hidden_layers,
        n_head=opt_config.num_attention_heads,
        n_inner=opt_config.ffn_dim,
        activation_function=opt_config.activation_function,
        resid_pdrop=opt_config.dropout,
        # HF's implementation of OPT doesn't seem to have embedding dropout
        embd_pdrop=opt_config.dropout,
        attn_pdrop=opt_config.attention_dropout,
        initializer_range=opt_config.init_std,
        bos_token_id=opt_config.bos_token_id,
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim,
    )
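

# Minimal usage sketch: convert a Hugging Face OPT checkpoint into the GPT-style
# key layout produced by the functions above. It assumes the `transformers`
# package is installed and that the example checkpoint name "facebook/opt-125m"
# can be downloaded; the model class that would consume the remapped weights is
# not shown here.
if __name__ == "__main__":
    from transformers import OPTForCausalLM

    model_name = "facebook/opt-125m"
    opt_config = OPTConfig.from_pretrained(model_name)
    # Translate the OPT config into the GPT2-style config the remapper expects
    # (it reads config.n_layer, config.vocab_size, etc.).
    config = opt_config_to_gpt2_config(opt_config)
    hf_state_dict = OPTForCausalLM.from_pretrained(model_name).state_dict()
    remapped = remap_state_dict_hf_opt(hf_state_dict, config)
    # Keys now follow the GPT naming convention, e.g.
    # 'transformer.layers.0.mixer.Wqkv.weight' instead of
    # 'model.decoder.layers.0.self_attn.q_proj.weight'.
    print(f"{len(remapped)} tensors remapped")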