# Copyright (c) 2023, Tri Dao.

import logging
import math
import re
from functools import partial
from collections import namedtuple, OrderedDict
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import GPT2Config
from einops import rearrange

from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
    ColumnParallelLinear = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
except ImportError:
    FusedDenseSqreluDense = None
    sqrelu_fwd = None


logger = logging.getLogger(__name__)


def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {'device': device, 'dtype': dtype}
    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, 'attn_dwconv', False)
    if dwconv:
        assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
    qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
    out_proj_bias = getattr(config, 'out_proj_bias', True)
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
    rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    if not fused_bias_fc:
        assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
                     if process_group is None else {})
    parallel_kwargs = ({'process_group': process_group,
                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                       if process_group is not None else {})
    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
                        qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                        dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                        rotary_emb_interleaved=rotary_emb_interleaved,
                        use_flash_attn=use_flash_attn,
                        **serial_kwargs, **parallel_kwargs, **factory_kwargs)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {'device': device, 'dtype': dtype}
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_mlp = getattr(config, 'fused_mlp', False)
    if fused_mlp:
        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
                                                        'supports approximate activation_function sqrelu')
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if process_group is not None:
        assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
        if config.activation_function == 'relu':
            activation = partial(F.relu, inplace=True)
        elif config.activation_function == 'sqrelu':
            assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
            activation = sqrelu_fwd
        else:
            approximate = ('tanh' if config.activation_function
                           in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
            activation = partial(F.gelu, approximate=approximate)
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError('fused_dense is not installed')
            activation = ('gelu_approx' if config.activation_function
                          in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = ({'process_group': process_group,
                                'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                               if process_group is not None else {})
            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
                              checkpoint_lvl=mlp_checkpoint_lvl,
                              **parallel_kwargs, **factory_kwargs)
        elif fused_dense_sqrelu_dense:
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
        else:
            raise RuntimeError('MLP type not supported')
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {'device': device, 'dtype': dtype}
    sequence_parallel = getattr(config, 'sequence_parallel', True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, 'prenorm', True)
    parallel_block = getattr(config, 'parallel_block', False)
    if not parallel_block:
        block = Block(
            config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
            prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
            resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, 'parallel_block_tied_norm', False),
            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None
        )
    block.layer_idx = layer_idx
    return block
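

# Illustrative sketch added for documentation (not part of the original module):
# build a single non-parallel Block from a small, arbitrary GPT2Config, mirroring
# the per-layer call made in GPTModel.__init__ below. The hyperparameters here are
# made up and do not correspond to any released model.
def _example_create_block():
    cfg = GPT2Config(n_embd=256, n_head=4, n_layer=2)
    # layer_idx=0 makes resid_dropout1 use embd_pdrop instead of resid_pdrop.
    return create_block(cfg, layer_idx=0)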


class GPTPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
                        world_size=1, rank=0, **kwargs):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(
            model_name, device='cpu', dtype=dtype
        )
        if model_name.startswith('gpt2'):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith('facebook/opt'):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith('EleutherAI/gpt-j-'):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
            strict = False  # We have rotary_emb.inv_freq buffers not in the GPT-J checkpoint
        elif model_name.startswith('EleutherAI/gpt-neox-'):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        else:
            raise NotImplementedError(f'Model {model_name} not supported')
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
        logger.info(load_return)
        return model


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPTModel(GPTPreTrainedModel):

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
                                              'relu', 'sqrelu']
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, 'prenorm', True)
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, 'parallel_block', False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size, vocab_size, config.max_position_embeddings,
                process_group=process_group, sequence_parallel=self.sequence_parallel,
                **factory_kwargs
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                  **factory_kwargs)
                                     for i in range(config.num_hidden_layers)])

        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln:
            if ((not self.parallel_block and dropout_add_layer_norm is None)
                    or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
                raise ImportError('dropout_layer_norm is not installed')
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
                                     **factory_kwargs)
            if process_group is not None:
                for p in self.ln_f.parameters():
                    # Mark the norm parameters as "shared_params" so that we sync their values at init.
                    p._shared_params = True
                    # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                    if self.sequence_parallel:
                        p._sequence_parallel = True

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = ({'combine_batch_seqlen_dim': True}
                            if self.process_group is not None and self.sequence_parallel else {})
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = ({'seqlen': input_ids.shape[1]}
                        if self.process_group is not None and self.sequence_parallel else {})
        if inference_params is not None:
            mixer_kwargs['inference_params'] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(hidden_states, residual,
                                                    mixer_kwargs=mixer_kwargs)
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = ((residual + dropped + dropped2)
                                if residual is not None else dropped + dropped2)
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                if not self.parallel_block:
                    hidden_states = dropout_add_layer_norm(
                        hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                        self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                        residual_in_fp32=self.residual_in_fp32
                    )
                else:
                    hidden_states, _ = dropout_add_layer_norm_parallel_residual(
                        hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
                        None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
                        prenorm=False, residual_in_fp32=self.residual_in_fp32
                    )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):

    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
        lm_head_bias = getattr(config, 'lm_head_bias', False)
        pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError('fused_dense_lib is not installed')
            self.lm_head = ColumnParallelLinear(
                embed_dim, vocab_size, process_group, bias=lm_head_bias,
                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
            )
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                         inference_params=inference_params)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        lm_logits = self.lm_head(hidden_states)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if 'transformer.ln_0.weight' in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
            ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
            state_dict['transformer.ln_f.weight'] = ln_weight
            state_dict['transformer.ln_f.bias'] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
                ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
                state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
                state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
                    ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
                    state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
                    state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
            ln_weight = state_dict.pop('transformer.ln_0.weight')
            ln_bias = state_dict.pop('transformer.ln_0.bias')
            state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
            state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
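

# Illustrative sketch (not part of the original module): load the HuggingFace
# 'gpt2' checkpoint through from_pretrained, which remaps the HF state dict via
# remap_state_dict_hf_gpt2 defined below. This assumes a CUDA device is available;
# device/dtype placement follows the factory_kwargs convention used in this file.
def _example_from_pretrained():
    config = GPT2Config.from_pretrained('gpt2')
    model = GPTLMHeadModel.from_pretrained('gpt2', config,
                                           device='cuda', dtype=torch.float16)
    return model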


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
    """
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    def shard_first_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[0] // world_size
        state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[-1] // world_size
        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_qkv_headdim(state_dict, key):
        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
        dim = x.shape[1] // world_size
        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                    'three d ... -> (three d) ...')

    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        shard_first_dim(state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
    return state_dict
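

# Illustrative sketch (not part of the original module): produce one shard per
# tensor-parallel rank from a full, unsharded state dict. This mirrors the call
# from_pretrained makes when world_size > 1; each rank gets a shallow copy of the
# dict because shard_state_dict_tp mutates the dict it is given.
def _example_shard_for_ranks(full_state_dict, config, world_size=2):
    return [shard_state_dict_tp(dict(full_state_dict), config, world_size, rank)
            for rank in range(world_size)]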


def combine_state_dicts_tp(state_dicts, config):
    """Convert the sharded state_dicts of a GPT model with tensor parallel (one per rank)
    back into the state_dict of a standard GPT model.
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    # The word embeddings from Megatron are weird, for each shard only the first
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
        assert all(s[key].shape[0] == vocab_size for s in state_dicts)
        state_dict[key] = torch.cat([s[key][:vocab_size // world_size] for s in state_dicts], dim=0)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
        state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight', 0)
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
    return state_dict


def remap_state_dict_hf_gpt2(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('wte.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
        key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
        W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
        key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f'h.{d}.attn.bias')  # We don't store this bias
        Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
        Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
        state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
        key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def remap_state_dict_megatron(state_dict, config):
    def key_mapping_transformer(key):
        key = re.sub(r'^language_model.encoder.', 'transformer.', key)
        key = re.sub(r'^language_model.', 'transformer.', key)
        return key
    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key)
        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
                     r'transformer.layers.\1.norm1.\2', key)
        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
                     r'transformer.layers.\1.norm2.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
                     r'transformer.layers.\1.mlp.fc1.\2', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
                     r'transformer.layers.\1.mlp.fc2.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
                     r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key)
        key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
                     r'transformer.layers.\1.mixer.Wqkv.\2', key)
        key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
                     r'transformer.layers.\1.mixer.out_proj.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange(
            Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
            three=3, headdim=headdim
        )
        bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias')
        state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange(
            bqkv, '(nheads three headdim) -> (three nheads headdim)',
            three=3, headdim=headdim
        )
    return state_dict
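

# Minimal usage sketch (added for illustration, not part of the original module):
# build a tiny randomly initialized GPTLMHeadModel and run a forward pass. The
# hyperparameters are arbitrary; with the default config flags this exercises the
# plain (non-fused, non-flash) code paths, so it should also run on CPU.
if __name__ == '__main__':
    config = GPT2Config(vocab_size=1024, n_positions=128, n_embd=256, n_layer=2, n_head=4)
    model = GPTLMHeadModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids)
    # CausalLMOutput namedtuple with logits of shape (batch, seqlen, vocab_size)
    print(out.logits.shape)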