baichuan.py

# Copyright (c) 2023, GGGGGGXY, Tri Dao.
import math
import json
import re
from pathlib import Path
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, AutoConfig, PretrainedConfig


def remap_state_dict_hf_baichuan(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^model.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.embed_tokens.",
            "transformer.embeddings.word_embeddings.",
            key,
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
        * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict[
            "transformer.embeddings.word_embeddings.weight"
        ]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since Baichuan shards the word embeddings
        # and output embeddings differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat(
            [w3, w1], dim=0
        )

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).mlp.down_proj.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn.W_pack.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn.o_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    for l in range(config.n_layer):
        # pop rotary_emb.inv_freq from state dict
        state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    return state_dict
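

# --- Illustrative sketch (not part of the original file) ---------------------
# A toy sanity check of the fc1 layout produced by the remap above. It assumes
# flash_attn's gated MLP chunks the fc1 output in half along the last dim and
# passes the second half through the activation as the gate, which is why
# up_proj (w3) is stacked first and gate_proj (w1) second.
def _check_fc1_layout():
    hidden, intermediate = 8, 16  # toy sizes, purely hypothetical
    w1 = torch.randn(intermediate, hidden)  # gate_proj.weight
    w3 = torch.randn(intermediate, hidden)  # up_proj.weight
    fc1 = torch.cat([w3, w1], dim=0)  # same ordering as in the remap above
    x = torch.randn(2, hidden)
    up, gate = F.linear(x, fc1).chunk(2, dim=-1)
    assert torch.allclose(up, F.linear(x, w3))
    assert torch.allclose(gate, F.linear(x, w1))
    # SwiGLU then computes up * silu(gate), matching HF Baichuan's
    # down_proj(silu(gate_proj(x)) * up_proj(x)).
    return up * F.silu(gate)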
def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
    # HACK: the config doesn't say whether it's rotary or alibi.
    # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
    # HACK: the config doesn't say whether it uses norm head.
    # So we have to infer from the vocab size
    # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
    use_rotary = baichuan_config.hidden_size < 5000
    return GPT2Config(
        vocab_size=baichuan_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=baichuan_config.hidden_size,
        n_layer=baichuan_config.num_hidden_layers,
        n_head=baichuan_config.num_attention_heads,
        n_inner=baichuan_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Baichuan doesn't have dropout, possibly because they only released the inference code.
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=baichuan_config.rms_norm_eps,
        initializer_range=baichuan_config.initializer_range,
        bos_token_id=baichuan_config.bos_token_id,
        eos_token_id=baichuan_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=baichuan_config.pad_token_id,  # Not sure if this does anything
        rms_norm=True,
        rotary_emb_fraction=1.0 if use_rotary else 0.0,
        rotary_emb_interleaved=False,
        use_alibi=not use_rotary,
        use_flash_attn=not use_rotary,  # ALiBi code path requires flash_attn
        tie_word_embeddings=False,
        norm_head=baichuan_config.vocab_size > 70000,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
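

# --- Usage sketch (not part of the original file) ----------------------------
# A minimal example of how these two converters might be combined, assuming
# flash_attn's GPTLMHeadModel and state_dict_from_pretrained helpers are
# available as for the other flash_attn.models converters. The checkpoint name
# is only an example.
if __name__ == "__main__":
    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.utils.pretrained import state_dict_from_pretrained

    model_name = "baichuan-inc/Baichuan-7B"  # example HF checkpoint
    # Baichuan's config/modeling code lives in the HF repo, hence trust_remote_code.
    hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = baichuan_config_to_gpt2_config(hf_config)

    # Fetch the HF weights, rename the keys to the flash_attn layout, and load.
    state_dict = state_dict_from_pretrained(model_name)
    state_dict = remap_state_dict_hf_baichuan(state_dict, config)
    model = GPTLMHeadModel(config)
    model.load_state_dict(state_dict)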