falcon.py

# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import FalconConfig, GPT2Config

def remap_state_dict_hf_falcon(state_dict, config):
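    """Map a Hugging Face FalconForCausalLM state dict onto the
    transformer.layers.{i}.{norm1,norm2,mixer,mlp} naming used here: rename the
    embedding / LayerNorm / MLP / attention keys, pad the embedding matrices up to a
    multiple of config.pad_vocab_size_multiple, and re-pack the interleaved Falcon QKV
    projection into concatenated [Q, K, V] blocks.
    """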
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
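
    # Attention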
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", 1)
    headdim = config.hidden_size // n_head
    for l in range(config.n_layer):
        # The QKV weights are stored in a different layout compared to our implementation:
        # Falcon stacks one group of [n_head // n_head_kv query heads, 1 key head, 1 value head]
        # per KV head. Split the groups apart and re-concatenate so that all Q rows come first,
        # then all K rows, then all V rows.
        Wqkv = rearrange(
            state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
            "(group ratio headdim) ... -> group ratio headdim ...",
            ratio=n_head // n_head_kv + 2,
            headdim=headdim,
        )
        Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
        Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
        Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
    return state_dict


def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
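    """Translate a Hugging Face FalconConfig into the GPT2Config (extended with fields such
    as n_head_kv, parallel_block, and rotary_emb_fraction) that goes with the state dict
    produced by remap_state_dict_hf_falcon.
    """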
    # The 40b config uses "n_head_kv" instead of "num_kv_heads"
    n_head_kv = getattr(
        falcon_config,
        "n_head_kv",
        1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
    )
    # HACK: the 40b config has 2 LayerNorms per layer instead of 1, but that's not reflected
    # in the config, so we have to infer it from the number of KV heads: the multi-query
    # models (n_head_kv == 1) use a single norm shared by the attention and MLP branches.
    parallel_block_tied_norm = n_head_kv == 1
    return GPT2Config(
        vocab_size=falcon_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=falcon_config.hidden_size,
        n_layer=falcon_config.n_layer,
        n_head=falcon_config.n_head,
        n_inner=falcon_config.hidden_size * 4,
        activation_function="gelu",
        resid_pdrop=falcon_config.hidden_dropout,
        embd_pdrop=0.0,  # There doesn't seem to be any embedding dropout
        attn_pdrop=falcon_config.attention_dropout,
        layer_norm_epsilon=falcon_config.layer_norm_epsilon,
        initializer_range=falcon_config.initializer_range,
        bos_token_id=falcon_config.bos_token_id,
        eos_token_id=falcon_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        parallel_block=falcon_config.parallel_attn,
        n_head_kv=n_head_kv,
        parallel_block_tied_norm=parallel_block_tied_norm,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=False,
        tie_word_embeddings=True,
        qkv_proj_bias=falcon_config.bias,
        out_proj_bias=falcon_config.bias,
        mlp_fc1_bias=falcon_config.bias,
        mlp_fc2_bias=falcon_config.bias,
        lm_head_bias=False,
    )
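

if __name__ == "__main__":
    # Minimal usage sketch: build a toy, single-layer Falcon-style state dict (made-up sizes,
    # no real checkpoint) and run it through remap_state_dict_hf_falcon to show the renamed
    # keys and the re-packed QKV weight. With a real checkpoint you would instead pass the
    # Hugging Face model's state_dict() together with the config returned by
    # falcon_config_to_gpt2_config.
    hidden_size, n_head, n_head_kv, vocab_size = 16, 4, 1, 10
    headdim = hidden_size // n_head
    # Rows of the fused QKV projection: one [q_0 .. q_{n_head // n_head_kv - 1}, k, v] group
    # per KV head.
    qkv_rows = (n_head // n_head_kv + 2) * headdim * n_head_kv
    toy_config = GPT2Config(
        vocab_size=vocab_size, n_embd=hidden_size, n_layer=1, n_head=n_head, n_head_kv=n_head_kv
    )
    toy_state_dict = OrderedDict(
        {
            "transformer.word_embeddings.weight": torch.randn(vocab_size, hidden_size),
            "transformer.h.0.input_layernorm.weight": torch.ones(hidden_size),
            "transformer.h.0.input_layernorm.bias": torch.zeros(hidden_size),
            "transformer.h.0.self_attention.query_key_value.weight": torch.randn(
                qkv_rows, hidden_size
            ),
            "transformer.h.0.self_attention.dense.weight": torch.randn(hidden_size, hidden_size),
            "transformer.h.0.mlp.dense_h_to_4h.weight": torch.randn(4 * hidden_size, hidden_size),
            "transformer.h.0.mlp.dense_4h_to_h.weight": torch.randn(hidden_size, 4 * hidden_size),
            "transformer.ln_f.weight": torch.ones(hidden_size),
            "transformer.ln_f.bias": torch.zeros(hidden_size),
        }
    )
    remapped = remap_state_dict_hf_falcon(toy_state_dict, toy_config)
    for name, tensor in remapped.items():
        print(f"{name}: {tuple(tensor.shape)}")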