gpt.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # Copyright (c) 2023, Tri Dao.
  2. import logging
  3. import math
  4. import re
  5. from functools import partial
  6. from collections import namedtuple, OrderedDict
  7. from collections.abc import Sequence
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from transformers import GPT2Config
  12. from einops import rearrange
  13. from flash_attn.modules.mha import MHA, ParallelMHA
  14. from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
  15. from flash_attn.modules.block import Block
  16. from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
  17. from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
  18. from flash_attn.utils.pretrained import state_dict_from_pretrained
  19. from flash_attn.utils.generation import GenerationMixin
  20. from flash_attn.models.opt import remap_state_dict_opt
  21. try:
  22. from flash_attn.ops.fused_dense import ColumnParallelLinear
  23. except ImportError:
  24. ColumnParallelLinear = None
  25. try:
  26. from flash_attn.ops.layer_norm import dropout_add_layer_norm
  27. except ImportError:
  28. dropout_add_layer_norm = None
  29. try:
  30. from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
  31. except ImportError:
  32. FusedDenseSqreluDense = None
  33. logger = logging.getLogger(__name__)
  34. def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
  35. factory_kwargs = {'device': device, 'dtype': dtype}
  36. head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
  37. softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
  38. if config.scale_attn_by_inverse_layer_idx:
  39. assert layer_idx is not None
  40. softmax_scale /= float(layer_idx + 1)
  41. dwconv = getattr(config, 'attn_dwconv', False)
  42. if dwconv:
  43. assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
  44. rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
  45. rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
  46. use_flash_attn = getattr(config, 'use_flash_attn', False)
  47. fused_bias_fc = getattr(config, 'fused_bias_fc', False)
  48. if not fused_bias_fc:
  49. assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
  50. mha_cls = MHA if process_group is None else ParallelMHA
  51. serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
  52. if process_group is None else {})
  53. parallel_kwargs = ({'process_group': process_group,
  54. 'sequence_parallel': getattr(config, 'sequence_parallel', True)}
  55. if process_group is not None else {})
  56. mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
  57. softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
  58. rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
  59. use_flash_attn=use_flash_attn,
  60. **serial_kwargs, **parallel_kwargs, **factory_kwargs)
  61. return mixer_cls
  62. def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
  63. factory_kwargs = {'device': device, 'dtype': dtype}
  64. inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
  65. fused_mlp = getattr(config, 'fused_mlp', False)
  66. if fused_mlp:
  67. assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
  68. fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
  69. if fused_dense_sqrelu_dense:
  70. assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
  71. 'supports approximate activation_function sqrelu')
  72. assert not (fused_dense_sqrelu_dense and fused_mlp)
  73. if process_group is not None:
  74. assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
  75. if not fused_mlp and not fused_dense_sqrelu_dense:
  76. if config.activation_function == 'relu':
  77. activation = partial(F.relu, inplace=True)
  78. else:
  79. approximate = ('tanh' if config.activation_function
  80. in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
  81. activation=partial(F.gelu, approximate=approximate)
  82. mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
  83. else:
  84. mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
  85. # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
  86. if isinstance(mlp_checkpoint_lvl, Sequence):
  87. assert layer_idx is not None
  88. mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
  89. if fused_mlp:
  90. if FusedMLP is None:
  91. raise ImportError('fused_dense is not installed')
  92. activation = ('gelu_approx' if config.activation_function
  93. in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
  94. mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
  95. parallel_kwargs = ({'process_group': process_group,
  96. 'sequence_parallel': getattr(config, 'sequence_parallel', True)}
  97. if process_group is not None else {})
  98. mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
  99. checkpoint_lvl=mlp_checkpoint_lvl,
  100. **parallel_kwargs, **factory_kwargs)
  101. elif fused_dense_sqrelu_dense:
  102. assert FusedDenseSqreluDense is not None
  103. mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
  104. checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
  105. else:
  106. raise RuntimeError('MLP type not supported')
  107. return mlp_cls
  108. def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
  109. factory_kwargs = {'device': device, 'dtype': dtype}
  110. sequence_parallel = getattr(config, 'sequence_parallel', True)
  111. mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
  112. mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
  113. norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
  114. # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
  115. residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
  116. resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
  117. prenorm = getattr(config, 'prenorm', True)
  118. block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
  119. prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
  120. fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
  121. residual_in_fp32=residual_in_fp32,
  122. sequence_parallel=sequence_parallel and process_group is not None,
  123. mark_shared_params=process_group is not None)
  124. block.layer_idx = layer_idx
  125. return block
  126. class GPTPreTrainedModel(nn.Module):
  127. """ An abstract class to handle weights initialization and
  128. a simple interface for dowloading and loading pretrained models.
  129. """
  130. def __init__(self, config, *inputs, **kwargs):
  131. super().__init__()
  132. if not isinstance(config, GPT2Config):
  133. raise ValueError(
  134. "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
  135. "To create a model from a Google pretrained model use "
  136. "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
  137. self.__class__.__name__, self.__class__.__name__
  138. ))
  139. self.config = config
  140. @classmethod
  141. def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
  142. world_size=1, rank=0, **kwargs):
  143. """
  144. Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
  145. Download and cache the pre-trained model file if needed.
  146. """
  147. # Instantiate model.
  148. model = cls(config, *args, device=device, dtype=dtype, **kwargs)
  149. # If we're going to shard the model, then don't load fp32 weights to GPU.
  150. state_dict = state_dict_from_pretrained(
  151. model_name, device=device if world_size == 1 else None, dtype=dtype
  152. )
  153. if model_name.startswith('gpt2'):
  154. state_dict = remap_state_dict_gpt2(state_dict, config)
  155. elif model_name.startswith('facebook/opt'):
  156. state_dict = remap_state_dict_opt(state_dict, config)
  157. else:
  158. raise NotImplementedError(f'Model {model_name} not supported')
  159. if world_size > 1:
  160. state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
  161. state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
  162. load_return = model.load_state_dict(state_dict, strict=strict)
  163. logger.info(load_return)
  164. return model
  165. # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
  166. def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
  167. if isinstance(module, nn.Linear):
  168. nn.init.normal_(module.weight, std=initializer_range)
  169. if module.bias is not None:
  170. nn.init.zeros_(module.bias)
  171. elif isinstance(module, nn.Embedding):
  172. nn.init.normal_(module.weight, std=initializer_range)
  173. if rescale_prenorm_residual:
  174. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  175. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  176. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  177. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  178. #
  179. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  180. for name, p in module.named_parameters():
  181. if name in ["out_proj.weight", "fc2.weight"]:
  182. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  183. nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
  184. class GPTModel(GPTPreTrainedModel):
  185. def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
  186. super().__init__(config)
  187. factory_kwargs = {'device': device, 'dtype': dtype}
  188. self.process_group = process_group
  189. self.sequence_parallel = getattr(config, 'sequence_parallel', True)
  190. assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
  191. 'relu', 'sqrelu']
  192. pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
  193. vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
  194. * pad_vocab_size_multiple)
  195. # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
  196. self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
  197. # These 2 options are for OPT-350m
  198. self.prenorm = getattr(config, 'prenorm', True)
  199. word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
  200. if process_group is None:
  201. self.embeddings = GPT2Embeddings(
  202. config.hidden_size, vocab_size, config.max_position_embeddings,
  203. word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
  204. )
  205. else:
  206. self.embeddings = ParallelGPT2Embeddings(
  207. config.hidden_size, vocab_size, config.max_position_embeddings,
  208. process_group=process_group, sequence_parallel=self.sequence_parallel,
  209. **factory_kwargs
  210. )
  211. # We change the order of dropout, residual and layer norm:
  212. # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
  213. # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
  214. # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
  215. # nn.Dropout probabilities are changed.
  216. # This is for performance reason: we can fuse dropout + add + layer_norm.
  217. self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
  218. **factory_kwargs)
  219. for i in range(config.num_hidden_layers)])
  220. self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
  221. if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
  222. raise ImportError('dropout_add_layer_norm is not installed')
  223. if self.prenorm:
  224. self.drop_f = nn.Dropout(config.resid_pdrop)
  225. self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
  226. **factory_kwargs)
  227. if process_group is not None:
  228. for p in self.ln_f.parameters():
  229. # Mark the norm parameters as "shared_params" so that we sync their values at init.
  230. p._shared_params = True
  231. # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
  232. if self.sequence_parallel:
  233. p._sequence_parallel = True
  234. self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
  235. initializer_range=config.initializer_range))
  236. self.tie_weights()
  237. def tie_weights(self):
  238. if self.process_group is not None:
  239. sync_shared_params(self, self.process_group)
  240. def forward(self, input_ids, position_ids=None, inference_params=None):
  241. # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
  242. # dimensions so that we can split on it easily, in case of small batch size.
  243. # Only the attention layers need to know the seqlen.
  244. embedding_kwargs = ({'combine_batch_seqlen_dim': True}
  245. if self.process_group is not None and self.sequence_parallel else {})
  246. hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
  247. residual = None
  248. mixer_kwargs = ({'seqlen': input_ids.shape[1]}
  249. if self.process_group is not None and self.sequence_parallel else {})
  250. if inference_params is not None:
  251. mixer_kwargs['inference_params'] = inference_params
  252. for layer in self.layers:
  253. if self.prenorm:
  254. hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
  255. else:
  256. hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
  257. if self.prenorm:
  258. if not self.fused_dropout_add_ln:
  259. dropped = self.drop_f(hidden_states)
  260. residual = (dropped + residual) if residual is not None else dropped
  261. hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
  262. else:
  263. # Set prenorm=False here since we don't need the residual
  264. hidden_states = dropout_add_layer_norm(
  265. hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
  266. self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
  267. residual_in_fp32=self.residual_in_fp32
  268. )
  269. return hidden_states
  270. class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
  271. def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
  272. factory_kwargs = {'device': device, 'dtype': dtype}
  273. super().__init__(config)
  274. self.process_group = process_group
  275. self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
  276. pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
  277. vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
  278. * pad_vocab_size_multiple)
  279. # This option is for OPT-350m
  280. word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
  281. embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
  282. if word_embed_proj_dim is not None:
  283. self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
  284. else:
  285. self.project_out = None
  286. if process_group is None:
  287. self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs)
  288. else:
  289. if ColumnParallelLinear is None:
  290. raise ImportError('fused_dense_lib is not installed')
  291. self.lm_head = ColumnParallelLinear(
  292. embed_dim, vocab_size, process_group, bias=False,
  293. sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
  294. )
  295. # Initialize weights and apply final processing
  296. self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
  297. initializer_range=config.initializer_range))
  298. self.tie_weights()
  299. def tie_weights(self):
  300. self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
  301. if self.process_group is not None:
  302. sync_shared_params(self, self.process_group)
  303. def forward(self, input_ids, position_ids=None, inference_params=None):
  304. """
  305. inference_params: for generation. Adapted from Megatron-LM (and Apex)
  306. https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
  307. """
  308. hidden_states = self.transformer(input_ids, position_ids=position_ids,
  309. inference_params=inference_params)
  310. if self.project_out is not None:
  311. hidden_states = self.project_out(hidden_states)
  312. lm_logits = self.lm_head(hidden_states)
  313. # During inference, we want the full logit for sampling
  314. if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
  315. lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
  316. lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
  317. CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
  318. return CausalLMOutput(logits=lm_logits)
  319. def load_state_dict(self, state_dict, strict=True):
  320. # Remapping from our checkpoints that used a different ordering of layers in the block
  321. # Previous: Attn / MLP -> Dropout -> Add -> LN
  322. # Current: Dropout -> Add -> LN -> Attn / MLP
  323. if 'transformer.ln_0.weight' in state_dict:
  324. n_layers = len(self.transformer.layers)
  325. ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
  326. ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
  327. state_dict['transformer.ln_f.weight'] = ln_weight
  328. state_dict['transformer.ln_f.bias'] = ln_bias
  329. for l in reversed(range(n_layers)):
  330. ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
  331. ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
  332. state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
  333. state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
  334. if l > 0:
  335. ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
  336. ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
  337. state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
  338. state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
  339. ln_weight = state_dict.pop('transformer.ln_0.weight')
  340. ln_bias = state_dict.pop('transformer.ln_0.bias')
  341. state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
  342. state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
  343. return super().load_state_dict(state_dict, strict=strict)
  344. def remap_state_dict_gpt2(state_dict, config):
  345. # Word embedding and position embedding
  346. def key_mapping_pos_emb(key):
  347. return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
  348. state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
  349. word_embeddings = state_dict.pop('wte.weight')
  350. # It's possible that vocab_size is padded to be a multiple of 8, for example.
  351. pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
  352. vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
  353. state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
  354. word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
  355. )
  356. state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
  357. # LayerNorm
  358. def key_mapping_ln(key):
  359. key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
  360. key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
  361. return key
  362. state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
  363. # MLP
  364. for d in range(config.num_hidden_layers):
  365. W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
  366. state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
  367. W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
  368. state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
  369. def key_mapping_mlp(key):
  370. key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
  371. key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
  372. return key
  373. state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
  374. # Attention
  375. for d in range(config.num_hidden_layers):
  376. state_dict.pop(f'h.{d}.attn.bias') # We don't store this bias
  377. Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
  378. state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
  379. Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
  380. state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
  381. def key_mapping_attn(key):
  382. key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
  383. key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
  384. return key
  385. state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
  386. return state_dict
  387. def shard_state_dict_tp(state_dict, config, world_size, rank):
  388. """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
  389. with tensor parallel.
  390. """
  391. pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
  392. vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
  393. assert vocab_size % world_size == 0
  394. assert config.hidden_size % world_size == 0
  395. inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
  396. assert inner_dim % world_size == 0
  397. def shard_first_dim(state_dict, key):
  398. x = state_dict[key]
  399. dim = x.shape[0] // world_size
  400. state_dict[key] = x[rank * dim:(rank + 1) * dim]
  401. def shard_last_dim(state_dict, key):
  402. x = state_dict[key]
  403. dim = x.shape[-1] // world_size
  404. state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
  405. def shard_qkv_headdim(state_dict, key):
  406. x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
  407. dim = x.shape[1] // world_size
  408. state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
  409. 'three d ... -> (three d) ...')
  410. shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
  411. if 'lm_head.weight' in state_dict:
  412. shard_first_dim(state_dict, 'lm_head.weight')
  413. if 'transformer.embeddings.position_embeddings.weight' in state_dict:
  414. shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
  415. for i in range(config.num_hidden_layers):
  416. shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
  417. shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
  418. shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
  419. if rank != 0:
  420. state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
  421. shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
  422. shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
  423. shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
  424. if rank != 0:
  425. state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
  426. return state_dict