1
0

bigcode.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import math
  2. import re
  3. from collections import OrderedDict
  4. import torch
  5. import torch.nn.functional as F
  6. from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig
  7. def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
  8. """
  9. Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
  10. """
  11. # Word embedding and position embedding
  12. def key_mapping_pos_emb(key):
  13. return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)
  14. state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
  15. word_embeddings = state_dict.pop("transformer.wte.weight")
  16. # It's possible that vocab_size is padded to be a multiple of 8, for example.
  17. pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
  18. vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
  19. state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
  20. word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
  21. )
  22. state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
  23. # LayerNorm
  24. def key_mapping_ln(key):
  25. key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
  26. key = re.sub(
  27. r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
  28. r"transformer.layers.\1.norm\2.\3",
  29. key,
  30. )
  31. return key
  32. state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
  33. def key_mapping_mlp(key):
  34. key = re.sub(
  35. r"^transformer.h.(\d+).mlp.c_fc.weight",
  36. r"transformer.layers.\1.mlp.fc1.weight",
  37. key,
  38. )
  39. key = re.sub(
  40. r"^transformer.h.(\d+).mlp.c_proj.weight",
  41. r"transformer.layers.\1.mlp.fc2.weight",
  42. key,
  43. )
  44. key = re.sub(
  45. r"^transformer.h.(\d+).mlp.c_fc.bias",
  46. r"transformer.layers.\1.mlp.fc1.bias",
  47. key,
  48. )
  49. key = re.sub(
  50. r"^transformer.h.(\d+).mlp.c_proj.bias",
  51. r"transformer.layers.\1.mlp.fc2.bias",
  52. key,
  53. )
  54. return key
  55. state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
  56. # TODO: add support for multi-head attention
  57. assert config.multi_query, "Only multi-query attention is supported"
  58. # Attention
  59. for d in range(config.num_hidden_layers):
  60. embed_dim = config.n_embd
  61. head_dim = embed_dim // config.n_head
  62. c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
  63. # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)
  64. # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
  65. # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
  66. # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
  67. q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
  68. # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
  69. k = torch.tile(k, (config.n_head, 1))
  70. v = torch.tile(v, (config.n_head, 1))
  71. state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)
  72. # same deal with the bias
  73. c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
  74. # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
  75. q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
  76. # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
  77. k = torch.tile(k, (config.n_head,))
  78. v = torch.tile(v, (config.n_head,))
  79. state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)
  80. def key_mapping_attn(key):
  81. key = re.sub(
  82. r"^transformer.h.(\d+).attn.c_proj.weight",
  83. r"transformer.layers.\1.mixer.out_proj.weight",
  84. key,
  85. )
  86. key = re.sub(
  87. r"^transformer.h.(\d+).attn.c_proj.bias",
  88. r"transformer.layers.\1.mixer.out_proj.bias",
  89. key,
  90. )
  91. return key
  92. state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
  93. return state_dict
  94. def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
  95. """
  96. Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.
  97. This function is meant to be the inverse of remap_state_dict_hf_bigcode.
  98. """
  99. # Word embedding and position embeddings
  100. def inv_key_mapping_pos_emb(key):
  101. return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)
  102. state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
  103. word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
  104. word_embeddings = word_embeddings[:, : config.vocab_size]
  105. state_dict["transformer.wte.weight"] = word_embeddings
  106. state_dict["lm_head.weight"] = word_embeddings
  107. # LayerNorm
  108. def inv_key_mapping_ln(key):
  109. key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
  110. key = re.sub(
  111. r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
  112. r"transformer.h.\1.ln_\2.\3",
  113. key,
  114. )
  115. return key
  116. state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())
  117. # MLPs
  118. def inv_key_mapping_mlp(key):
  119. key = re.sub(
  120. r"^transformer.layers.(\d+).mlp.fc1.weight",
  121. r"transformer.h.\1.mlp.c_fc.weight",
  122. key,
  123. )
  124. key = re.sub(
  125. r"^transformer.layers.(\d+).mlp.fc2.weight",
  126. r"transformer.h.\1.mlp.c_proj.weight",
  127. key,
  128. )
  129. key = re.sub(
  130. r"^transformer.layers.(\d+).mlp.fc1.bias",
  131. r"transformer.h.\1.mlp.c_fc.bias",
  132. key,
  133. )
  134. key = re.sub(
  135. r"^transformer.layers.(\d+).mlp.fc2.bias",
  136. r"transformer.h.\1.mlp.c_proj.bias",
  137. key,
  138. )
  139. return key
  140. state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())
  141. # Attention
  142. for d in range(config.num_hidden_layers):
  143. embed_dim = config.n_embd
  144. head_dim = embed_dim // config.n_head
  145. Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
  146. q, k, v = torch.split(
  147. Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
  148. )
  149. c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
  150. state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight
  151. # Same deal with the bias
  152. Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
  153. q, k, v = torch.split(
  154. Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
  155. )
  156. c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
  157. state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias
  158. def inv_key_mapping_attn(key):
  159. key = re.sub(
  160. r"^transformer.layers.(\d+).mixer.out_proj.weight",
  161. r"transformer.h.\1.attn.c_proj.weight",
  162. key,
  163. )
  164. key = re.sub(
  165. r"^transformer.layers.(\d+).mixer.out_proj.bias",
  166. r"transformer.h.\1.attn.c_proj.bias",
  167. key,
  168. )
  169. return key
  170. state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())
  171. return state_dict
  172. def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
  173. return GPT2Config(
  174. activation_function=bigcode_config.activation_function,
  175. attn_pdrop=bigcode_config.attn_pdrop,
  176. bos_token_id=bigcode_config.bos_token_id,
  177. embd_pdrop=bigcode_config.embd_pdrop,
  178. eos_token_id=bigcode_config.eos_token_id,
  179. initializer_range=bigcode_config.initializer_range,
  180. layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
  181. max_batch_size=bigcode_config.max_batch_size,
  182. max_sequence_length=bigcode_config.max_sequence_length,
  183. model_type=bigcode_config.model_type,
  184. multi_query=bigcode_config.multi_query,
  185. n_embd=bigcode_config.n_embd,
  186. n_head=bigcode_config.n_head,
  187. n_inner=bigcode_config.n_inner,
  188. n_layer=bigcode_config.n_layer,
  189. n_positions=bigcode_config.n_positions,
  190. resid_pdrop=bigcode_config.resid_pdrop,
  191. scale_attn_weights=bigcode_config.scale_attn_weights,
  192. summary_activation=bigcode_config.summary_activation,
  193. summary_first_dropout=bigcode_config.summary_first_dropout,
  194. summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
  195. summary_type=bigcode_config.summary_type,
  196. summary_use_proj=bigcode_config.summary_use_proj,
  197. use_cache=bigcode_config.use_cache,
  198. vocab_size=bigcode_config.vocab_size,
  199. )