# llama.py
# Copyright (c) 2023, Tri Dao.

import json
import math
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Union

import torch
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from transformers import GPT2Config, LlamaConfig
from einops import rearrange

def remap_state_dict_meta_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Meta format to standard GPT format.

    This function modifies state_dict in place.
    """

    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    state_dict.pop("transformer.rope.freqs", None)
    return state_dict
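
# A minimal, self-contained sketch (helper name and all sizes here are made up,
# not part of the original module): build a tiny fake Meta-format checkpoint for
# a 1-layer toy model and run remap_state_dict_meta_llama on it, checking that
# the embedding gets padded and that Wqkv / fc1 are the expected concatenations.
# Real checkpoints would instead come from state_dicts_from_checkpoint below.
def _demo_remap_meta_llama():
    dim, n_heads, ffn_dim, vocab = 8, 2, 16, 11
    config = GPT2Config(
        n_embd=dim, n_head=n_heads, n_layer=1, n_inner=ffn_dim,
        vocab_size=vocab, tie_word_embeddings=False, pad_vocab_size_multiple=8,
    )
    sd = {
        "tok_embeddings.weight": torch.randn(vocab, dim),
        "output.weight": torch.randn(vocab, dim),
        "norm.weight": torch.ones(dim),
        "layers.0.attention_norm.weight": torch.ones(dim),
        "layers.0.ffn_norm.weight": torch.ones(dim),
        "layers.0.attention.wq.weight": torch.randn(dim, dim),
        "layers.0.attention.wk.weight": torch.randn(dim, dim),
        "layers.0.attention.wv.weight": torch.randn(dim, dim),
        "layers.0.attention.wo.weight": torch.randn(dim, dim),
        "layers.0.feed_forward.w1.weight": torch.randn(ffn_dim, dim),
        "layers.0.feed_forward.w2.weight": torch.randn(dim, ffn_dim),
        "layers.0.feed_forward.w3.weight": torch.randn(ffn_dim, dim),
    }
    out = remap_state_dict_meta_llama(sd, config)
    # vocab=11 is padded up to the next multiple of pad_vocab_size_multiple=8, i.e. 16.
    assert out["transformer.embeddings.word_embeddings.weight"].shape == (16, dim)
    assert out["lm_head.weight"].shape == (16, dim)
    # Wqkv is the [Wq; Wk; Wv] concatenation, fc1 is the [w3; w1] concatenation.
    assert out["transformer.layers.0.mixer.Wqkv.weight"].shape == (3 * dim, dim)
    assert out["transformer.layers.0.mlp.fc1.weight"].shape == (2 * ffn_dim, dim)
    return out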

def remap_state_dict_hf_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Hugging Face format to standard GPT format.

    This function modifies state_dict in place.
    """

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^model.layers.(\d+).mlp.down_proj.",
            r"transformer.layers.\1.mlp.fc2.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^model.layers.(\d+).input_layernorm.",
            r"transformer.layers.\1.norm1.",
            key,
        )
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return rearrange(
            w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
        )

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
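
# Why inv_permute exists: the HF conversion script stores q_proj/k_proj with the
# two rotary halves de-interleaved, while this repo's rotary embedding expects
# the original interleaved layout (rotary_emb_interleaved=True below). A tiny
# round-trip sketch showing the two rearranges are exact inverses; the helper
# name and sizes are made up for illustration only.
def _demo_rotary_permute_roundtrip():
    n_head, head_dim, in_features = 2, 4, 8  # toy sizes; head_dim must be even
    w_meta = torch.randn(n_head * head_dim, in_features)  # interleaved (Meta) layout
    # Same reordering as the permute in convert_llama_weights_to_hf.py
    w_hf = rearrange(w_meta, "(h d two) n -> (h two d) n", d=head_dim // 2, two=2)
    # Same reordering as inv_permute in remap_state_dict_hf_llama
    w_back = rearrange(w_hf, "(h two d) n -> (h d two) n", d=head_dim // 2, two=2)
    assert torch.equal(w_back, w_meta)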

def inv_remap_state_dict_hf_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in standard GPT format to Hugging Face format.

    This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
    multiplier pad in the embedding and lm_head. That is, if the original embedding
    isn't a multiple of pad_vocab_size_multiple, then
    inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.

    This function modifies state_dict in place.
    """

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("model.embed_tokens.weight")
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["model.embed_tokens.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # MLP
    for l in range(config.n_layer):
        w3, w1 = torch.chunk(
            state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
        )
        state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
        state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.",
            r"model.layers.\1.mlp.down_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm1.",
            r"model.layers.\1.input_layernorm.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).norm2.",
            r"model.layers.\1.post_attention_layernorm.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def permute(w):
        return rearrange(
            w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
        )

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)
    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head
    q_dim = n_head * head_dim
    k_dim = v_dim = n_head_kv * head_dim

    # Attention
    for l in range(config.n_layer):
        Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
        Wq = Wqkv[:q_dim]
        Wk = Wqkv[q_dim : q_dim + k_dim]
        Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
        state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
        state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.",
            r"model.layers.\1.self_attn.o_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
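
# A minimal round-trip sketch (helper name and all sizes are made up): for a toy
# 1-layer model whose vocab size is already a multiple of pad_vocab_size_multiple,
# inv_remap_state_dict_hf_llama should exactly invert remap_state_dict_hf_llama,
# as the docstring above claims.
def _demo_hf_roundtrip():
    dim, n_heads, ffn_dim, vocab = 8, 2, 16, 16
    config = GPT2Config(
        n_embd=dim, n_head=n_heads, n_layer=1, n_inner=ffn_dim,
        vocab_size=vocab, tie_word_embeddings=False,
    )
    hf_sd = {
        "model.embed_tokens.weight": torch.randn(vocab, dim),
        "lm_head.weight": torch.randn(vocab, dim),
        "model.norm.weight": torch.ones(dim),
        "model.layers.0.input_layernorm.weight": torch.ones(dim),
        "model.layers.0.post_attention_layernorm.weight": torch.ones(dim),
        "model.layers.0.self_attn.q_proj.weight": torch.randn(dim, dim),
        "model.layers.0.self_attn.k_proj.weight": torch.randn(dim, dim),
        "model.layers.0.self_attn.v_proj.weight": torch.randn(dim, dim),
        "model.layers.0.self_attn.o_proj.weight": torch.randn(dim, dim),
        "model.layers.0.mlp.gate_proj.weight": torch.randn(ffn_dim, dim),
        "model.layers.0.mlp.up_proj.weight": torch.randn(ffn_dim, dim),
        "model.layers.0.mlp.down_proj.weight": torch.randn(dim, ffn_dim),
    }
    gpt_sd = remap_state_dict_hf_llama(dict(hf_sd), config)
    recovered = inv_remap_state_dict_hf_llama(gpt_sd, config)
    assert recovered.keys() == hf_sd.keys()
    assert all(torch.equal(recovered[k], hf_sd[k]) for k in hf_sd)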

def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / "params.json") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=params.get("n_kv_heads", None),
    )
    multiple_of = params.get("multiple_of", 1)
    ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)
    # Compute the hidden dimension of the MLP
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
    intermediate_size = 4 * config.hidden_size
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
    intermediate_size = int(2 * intermediate_size / 3)
    # Custom dim factor multiplier
    if ffn_dim_multiplier is not None:
        intermediate_size = int(ffn_dim_multiplier * intermediate_size)
    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
    config.intermediate_size = intermediate_size
    if "rope_theta" in params:
        config.rotary_emb_base = params["rope_theta"]
    config.vocab_size = 32000
    # Some CodeLlama checkpoints have vocab_size 32000, some 32016.
    # Sadly it's not specified in the `params.json` file :(
    tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
    if tokenizer.is_file():
        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
    return config
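
# Worked example of the MLP-width arithmetic above (helper name made up),
# using the LLaMA-7B hyperparameters dim=4096 and multiple_of=256 with no
# ffn_dim_multiplier: it reproduces the familiar intermediate_size of 11008.
def _demo_intermediate_size():
    dim, multiple_of = 4096, 256
    intermediate_size = int(2 * (4 * dim) / 3)  # 10922
    intermediate_size = multiple_of * (
        (intermediate_size + multiple_of - 1) // multiple_of
    )  # rounded up to the next multiple of 256 -> 11008
    assert intermediate_size == 11008
    return intermediate_size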

def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")

def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    if checkpoint_format == "meta":
        return config_from_meta_checkpoint(checkpoint_path, model_name)
    else:
        return config_from_hf_checkpoint(checkpoint_path, model_name)

def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> List[dict]:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [
        torch.load(path, map_location="cpu")
        for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    ]

def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=llama_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout (possibly because only the inference code was released)
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=llama_config.pad_token_id,  # Unclear whether this has any effect
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
        rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
        n_head_kv=llama_config.num_key_value_heads,
    )
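
# A hedged end-to-end sketch tying the pieces together for a single-shard Meta
# checkpoint (e.g. a 7B-style layout with one consolidated.00.pth file). The
# function name, default arguments, and directory layout are hypothetical;
# multi-shard checkpoints would additionally need their shards merged, which
# this sketch does not attempt.
def _demo_convert_meta_checkpoint(checkpoint_path="/path/to/llama_ckpts", model_name="7B"):
    llama_config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
    gpt2_config = llama_config_to_gpt2_config(llama_config)
    shards = state_dicts_from_checkpoint(checkpoint_path, model_name)
    assert len(shards) == 1, "this sketch does not merge sharded checkpoints"
    state_dict = remap_state_dict_meta_llama(shards[0], gpt2_config)
    return gpt2_config, state_dict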