btlm.py

# Copyright (c) 2023, Tri Dao.

import math
import json
import re
from pathlib import Path
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, AutoConfig, PretrainedConfig


def remap_state_dict_hf_btlm(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    if "transformer.wpe.weight" in state_dict:
        state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight")
        W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0)
        b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias")
        b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias")
        state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0)
        W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
    state_dict.pop("transformer.relative_pe.slopes")  # We don't store the Alibi slopes

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
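
# Illustration of the remapping above, for layer 0 (key names only). The left-hand
# side follows the Hugging Face / Cerebras BTLM checkpoint layout, the right-hand
# side this repo's GPT module layout:
#   transformer.h.0.ln_1.weight        -> transformer.layers.0.norm1.weight
#   transformer.h.0.attn.c_attn.weight -> transformer.layers.0.mixer.Wqkv.weight     (transposed)
#   transformer.h.0.attn.c_proj.weight -> transformer.layers.0.mixer.out_proj.weight (transposed)
#   transformer.h.0.mlp.c_fc.weight,
#   transformer.h.0.mlp.c_fc2.weight   -> transformer.layers.0.mlp.fc1.weight        (transposed, concatenated)
#   transformer.h.0.mlp.c_proj.weight  -> transformer.layers.0.mlp.fc2.weight        (transposed)
# The transposes are needed because the HF checkpoint stores GPT-2-style Conv1D
# weights as (in_features, out_features), while nn.Linear expects (out_features, in_features).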


def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=btlm_config.vocab_size,
        n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions,
        n_embd=btlm_config.hidden_size,
        n_layer=btlm_config.num_hidden_layers,
        n_head=btlm_config.num_attention_heads,
        n_inner=btlm_config.n_inner,
        activation_function=btlm_config.activation_function,
        resid_pdrop=btlm_config.resid_pdrop,
        embd_pdrop=btlm_config.embd_pdrop,
        attn_pdrop=btlm_config.attn_pdrop,
        layer_norm_epsilon=btlm_config.layer_norm_epsilon,
        initializer_range=btlm_config.initializer_range,
        bos_token_id=btlm_config.bos_token_id,
        eos_token_id=btlm_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        use_alibi=btlm_config.position_embedding_type == "alibi",
        use_flash_attn=btlm_config.position_embedding_type == "alibi",  # Alibi code path requires flash_attn
        mup_width_scale=btlm_config.mup_width_scale,
        mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,
        mup_output_multiplier=btlm_config.mup_output_alpha,
        mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,
        mlp_multiple_of=1,
    )
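

if __name__ == "__main__":
    # Usage sketch (not part of the original module): convert a Hugging Face BTLM
    # checkpoint into the layout expected by this repo's GPT model. The checkpoint
    # name and the GPTLMHeadModel import below are assumptions about the surrounding
    # setup; adjust them to your environment. BTLM ships custom modeling code on the
    # Hub, hence trust_remote_code=True.
    from transformers import AutoModelForCausalLM

    from flash_attn.models.gpt import GPTLMHeadModel  # assumed target model class

    model_name = "cerebras/btlm-3b-8k-base"  # example BTLM checkpoint
    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = btlm_config_to_gpt2_config(btlm_config)

    # Load the HF weights, remap the keys/shapes, then load them into the GPT model.
    hf_state_dict = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True
    ).state_dict()
    state_dict = remap_state_dict_hf_btlm(hf_state_dict, config)

    model = GPTLMHeadModel(config)
    # Depending on the config (e.g. fused kernels, rotary/Alibi buffers), strict
    # loading may need to be relaxed.
    model.load_state_dict(state_dict, strict=False)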