# Copyright (c) 2024, Tri Dao.

import logging
import math
import re
from collections import OrderedDict, namedtuple
from collections.abc import Sequence
from functools import partial
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config

from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.llama import remap_state_dict_hf_llama
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import (
    FusedMLP,
    GatedMlp,
    Mlp,
    ParallelFusedMLP,
    ParallelGatedMlp,
    ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather,
    all_gather_raw,
    get_dim_for_local_rank,
    sync_shared_params,
)
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None

logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_alibi = getattr(config, "use_alibi", False)
    window_size = getattr(config, "window_size", (-1, -1))
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
        if process_group is not None
        else {}
    )
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_alibi=use_alibi,
        window_size=window_size,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                multiple_of=mlp_multiple_of,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                approximate = (
                    "tanh"
                    if config.activation_function
                    in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            activation = (
                "gelu_approx"
                if config.activation_function
                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        nn.LayerNorm if not use_rms_norm else RMSNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    if not parallel_block:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    block.layer_idx = layer_idx
    return block
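
# Illustrative sketch of the factory functions above (assumes a plain GPT2Config with no
# tensor parallelism, so process_group=None; flash-attn specific flags such as
# use_flash_attn or fused_bias_fc are optional config attributes read via getattr):
#
#     config = GPT2Config(n_embd=1024, n_head=16, n_layer=24)
#     block = create_block(config, layer_idx=0)  # pre-norm Block wrapping MHA + MLP
#     # create_mixer_cls / create_mlp_cls return partials that Block instantiates with the
#     # hidden size, so all per-layer options come from the config rather than from Block.
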
class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
            "togethercomputer/GPT-JT-"
        ):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif (
            model_name.startswith("EleutherAI/gpt-neox-")
            or model_name.startswith("EleutherAI/pythia-")
            or model_name.startswith("togethercomputer/RedPajama-INCITE-")
        ):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith("tiiuae/falcon-"):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        elif model_name.startswith("meta-llama/Llama-"):
            state_dict = remap_state_dict_hf_llama(state_dict, config)
        elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
            state_dict = remap_state_dict_hf_bigcode(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model
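
# Illustrative usage sketch for from_pretrained (assumes the Hugging Face "gpt2" checkpoint
# and a matching GPT2Config; extra flash-attn options would be set on the config first):
#
#     from transformers import GPT2Config
#     config = GPT2Config.from_pretrained("gpt2")
#     model = GPTLMHeadModel.from_pretrained("gpt2", config, device="cuda", dtype=torch.float16)
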
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
):
    mup_init_scale = math.sqrt(mup_width_scale)
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
        optim_cfg = getattr(module.weight, "_optim", {})
        optim_cfg.update({"lr_multiplier": mup_width_scale})
        setattr(module.weight, "_optim", optim_cfg)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)
    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path
        #   > with model depth. Scale the weights of residual layers at initialization by a
        #   > factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(
                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
                )


class GPTModel(GPTPreTrainedModel):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)
        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )
        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add)
        # and the main branch (output of MLP). The model definition is unchanged, but the mapping
        # of the nn.Dropout probabilities is changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )
        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
        if rotary_emb_fraction > 0.0:
            # Tie all the RotaryEmbedding modules to share the same cos/sin cache
            for layer in self.layers[1:]:
                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if layer_norm_fn is None:
                raise ImportError("Triton is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.embeddings_multiplier != 1.0:
            hidden_states = hidden_states * self.embeddings_multiplier
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None if not self.parallel_block else hidden_states2,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm),
                )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=mup_width_scale,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            lm_head_weight = F.normalize(self.lm_head.weight)
            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict["transformer.layers.0.norm1.weight"] = ln_weight
            state_dict["transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
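
# Illustrative usage sketch for GPTLMHeadModel (assumes a GPT-2-small-style config; the
# use_flash_attn flag and fused kernels are optional and only needed for the fast paths):
#
#     config = GPT2Config(n_embd=768, n_head=12, n_layer=12, vocab_size=50257)
#     config.use_flash_attn = True
#     model = GPTLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
#     input_ids = torch.randint(0, config.vocab_size, (2, 128), device="cuda")
#     logits = model(input_ids).logits  # (batch, seqlen, vocab_size)
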
def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.

    This function modifies state_dict in place.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)
    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])
            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[
                                n_head
                                + n_head_kv
                                + beg_n_head_kv : n_head
                                + n_head_kv
                                + end_n_head_kv
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict
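
# Illustrative sketch of sharding a full checkpoint for tensor parallelism (assumes
# world_size ranks that each build their model with the same process_group; copy.deepcopy
# keeps the unsharded dict intact, since the function modifies its input in place):
#
#     import copy
#     full_sd = model_single.state_dict()
#     sharded = [
#         shard_state_dict_tp(copy.deepcopy(full_sd), config, world_size, rank)
#         for rank in range(world_size)
#     ]
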
def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.

    This function is meant to be the "reverse" of shard_state_dict_tp.

    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    assert config.hidden_size % config.n_head == 0
    headdim = config.hidden_size // config.n_head

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                        headdim=headdim,
                    )
                    for s, rank_n_head, rank_n_head_kv in zip(
                        state_dicts, n_head_each_rank, n_head_kv_each_rank
                    )
                ]
                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
                wk = torch.cat(
                    [
                        x[
                            n_head_each_rank[rank] : n_head_each_rank[rank]
                            + n_head_kv_each_rank[rank]
                        ]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wv = torch.cat(
                    [
                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wqkv = torch.cat([wq, wk, wv], dim=0)
                state_dict[key] = rearrange(
                    wqkv,
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict
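
# Illustrative round-trip sketch (assumes the shards were produced by shard_state_dict_tp
# with ranks 0..world_size-1 in order, matching the precondition above):
#
#     sharded = [shard_state_dict_tp(copy.deepcopy(full_sd), config, world_size, r)
#                for r in range(world_size)]
#     merged = combine_state_dicts_tp(sharded, config)
#     # `merged` should reconstruct `full_sd`.
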
def remap_state_dict_hf_gpt2(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias", None)  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(
            r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
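
# Illustrative sketch of remapping a Hugging Face GPT-2 checkpoint into this layout
# (assumes the raw "gpt2" checkpoint as returned by state_dict_from_pretrained, which is
# exactly what from_pretrained above does before calling this function):
#
#     config = GPT2Config.from_pretrained("gpt2")
#     hf_sd = state_dict_from_pretrained("gpt2")
#     sd = remap_state_dict_hf_gpt2(hf_sd, config)
#     model = GPTLMHeadModel(config)
#     model.load_state_dict(sd)
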
def remap_state_dict_megatron(state_dict, config):
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )
    return state_dict