bert.py

# Copyright (c) 2022, Tri Dao.
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

import re
import logging
from functools import partial
from collections.abc import Sequence
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput

from einops import rearrange

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual

try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
except ImportError:
    dropout_add_layer_norm, layer_norm = None, None

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = None


logger = logging.getLogger(__name__)


def create_mixer_cls(config, cross_attn=False, return_residual=False):
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
                        dropout=config.attention_probs_dropout_prob, causal=False,
                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
                        return_residual=return_residual)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, return_residual=False):
    inner_dim = config.intermediate_size
    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
    if fused_dense_gelu_dense:
        assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
                                                                'supports approximate gelu')
    if not fused_dense_gelu_dense:
        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
        mlp_cls = partial(Mlp, hidden_features=inner_dim,
                          activation=partial(F.gelu, approximate=approximate),
                          return_residual=return_residual)
    else:
        if FusedDenseGeluDense is None:
            raise ImportError('fused_dense is not installed')
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
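        # (Added illustration: for a 12-layer model, a hypothetical mlp_checkpoint_lvl = [1] * 4 + [0] * 8
        # would apply a heavier checkpointing level to the first four layers only; the list is
        # indexed by layer_idx below.)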
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                          checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
    return mlp_cls


def create_block(config, layer_idx=None):
    last_layer_subset = getattr(config, 'last_layer_subset', False)
    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
    # one layer) so we just choose not to return residual in this case.
    return_residual = not cross_attn
    mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
    mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=False, resid_dropout=config.hidden_dropout_prob,
                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
                  return_residual=return_residual)
    return block


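# Note (added summary of the getattr() calls above): the factory functions read optional speed flags
# from the BertConfig (use_flash_attn, fused_bias_fc, fused_dense_gelu_dense, mlp_checkpoint_lvl,
# fused_dropout_add_ln, last_layer_subset), and every flag defaults to the plain-PyTorch code path
# that needs none of the fused CUDA extensions. The usage sketch at the end of this file shows where
# these flags would be set.

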
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
def _init_weights(module, initializer_range=0.02):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.padding_idx is not None:
            nn.init.zeros_(module.weight[module.padding_idx])


class BertEncoder(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.use_flash_attn = getattr(config, 'use_flash_attn', False)
        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                     for i in range(config.num_hidden_layers)])

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
        """If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool
        """
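        # Illustration (added, not in the original file): with batch=1, seqlen=4, a subset_mask of
        #     torch.tensor([[False, True, False, True]])
        # requests output only for positions 1 and 3. The return value then has shape
        # (num_selected_tokens, hidden_size) instead of (batch, seqlen, hidden_size); in the
        # FlashAttention (varlen) path below, only the last layer is even computed for that subset.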
        if key_padding_mask is None or not self.use_flash_attn:
            mixer_kwargs = ({'key_padding_mask': key_padding_mask}
                            if key_padding_mask is not None else None)
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            if subset_mask is not None:
                hidden_states = hidden_states[subset_mask]
        else:
            batch, seqlen = hidden_states.shape[:2]
            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                hidden_states, key_padding_mask
            )
            mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
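            # Note (added illustration): unpad_input packs the valid tokens of every sequence into a
            # single (total_tokens, hidden) tensor. For two sequences with 3 and 5 real tokens,
            # cu_seqlens would be tensor([0, 3, 8], dtype=torch.int32) and max_seqlen_in_batch 5;
            # pad_input below uses `indices` to scatter the result back to (batch, seqlen, hidden).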
            if subset_mask is None:
                for layer in self.layers:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
            else:
                for layer in self.layers[:-1]:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                if key_padding_mask is not None:
                    subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
                                              (1, 0))
                else:
                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
                                              (1, 0))
                hidden_states_subset, hidden_states = index_first_axis_residual(
                    hidden_states, subset_idx
                )
                # It's ok to set max_seqlen_q to be much larger
                mixer_kwargs = {'x_kv': hidden_states,
                                'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
                                'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
        return hidden_states


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0] if pool else hidden_states
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
        self.transform_act_fn = nn.GELU(approximate=approximate)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        if not self.fused_dropout_add_ln:
            hidden_states = self.layer_norm(hidden_states)
        else:
            hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias,
                                       self.layer_norm.eps)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense

        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            model_name: either:
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        load_return = model.load_state_dict(
            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False)
        logger.info(load_return)
        return model


class BertModel(BertPreTrainedModel):
    def __init__(self, config: BertConfig, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += (self.pad_vocab_size_multiple
                                  - (config.vocab_size % self.pad_vocab_size_multiple))
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        assert config.position_embedding_type == 'absolute'
        assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']

        self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
                                         config.max_position_embeddings, config.type_vocab_size,
                                         padding_idx=config.pad_token_id)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                masked_tokens_mask=None):
        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
        we only want the output for the masked tokens. This means that we only compute the last
        layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
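        # (Added note: BertForPreTraining.forward below constructs masked_tokens_mask as `labels > 0`;
        # the CLS position is added to the subset here so the pooler still receives its input.)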
        hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                        token_type_ids=token_type_ids)
        # TD [2022-12-18]: Don't need to force residual in fp32
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_drop(hidden_states)
            hidden_states = self.emb_ln(hidden_states)
        else:
            hidden_states = dropout_add_layer_norm(
                hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
                self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
            )

        if masked_tokens_mask is not None:
            batch_size, seqlen = input_ids.shape[:2]
            # We also need the first column for the CLS token
            first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
                                         device=input_ids.device)
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask
        else:
            subset_mask = None

        sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
                                       subset_mask=subset_mask)

        if masked_tokens_mask is None:
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            if attention_mask is not None:
                subset_idx = subset_mask[attention_mask]
                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
            else:
                pool_input = sequence_output[first_col_mask[subset_mask]]
                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            pooled_output = (self.pooler(pool_input, pool=False)
                             if self.pooler is not None else None)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )


class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config: BertConfig):
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, 'dense_seq_output', False)
        # If last_layer_subset, we only need to compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, 'last_layer_subset', False)
        if self.last_layer_subset:
            assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
        use_xentropy = getattr(config, 'use_xentropy', False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError('xentropy_cuda is not installed')
        loss_cls = (nn.CrossEntropyLoss if not use_xentropy
                    else partial(CrossEntropyLoss, inplace_backward=True))

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                labels=None, next_sentence_label=None):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).
        Outputs:
            if `labels` and `next_sentence_label` are not `None`:
                Outputs the total_loss which is the sum of the masked language modeling loss and
                the next sentence classification loss.
            if `labels` or `next_sentence_label` is `None`:
                Outputs a tuple comprising
                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
                - the next sentence classification logits of shape [batch_size, 2].
        """
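        # Illustration of the label convention (added example): with
        #     labels = torch.tensor([[0, 0, 2157, 0]])
        # only position 2 contributes to the MLM loss, since self.mlm_loss uses ignore_index=0.
        # next_sentence_label follows the usual BERT convention (0 = "is next", 1 = "not next");
        # a value of -1 is ignored because self.nsp_loss uses ignore_index=-1.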
        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
        outputs = self.bert(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        if self.dense_seq_output and labels is not None:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            if not self.last_layer_subset:
                sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'),
                                                   masked_token_idx)
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            if self.dense_seq_output and labels is not None:  # prediction_scores are already flattened
                masked_lm_loss = self.mlm_loss(prediction_scores,
                                               labels.flatten()[masked_token_idx])
            else:
                masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'),
                                               rearrange(labels, '... -> (...)'))
            next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
                                               rearrange(next_sentence_label, '... -> (...)'))
            total_loss = masked_lm_loss.float() + next_sentence_loss.float()

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
        )


def state_dict_from_pretrained(model_name):
    from transformers.utils import WEIGHTS_NAME
    from transformers.utils.hub import cached_file
    return torch.load(cached_file(model_name, WEIGHTS_NAME))


def remap_state_dict(state_dict, config):
    # LayerNorm
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
        key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())

    # Layers
    def key_mapping_layers(key):
        return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key)
        key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
                     r'bert.encoder.layers.\1.norm1.\2', key)
        key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
                     r'bert.encoder.layers.\1.norm2.\2', key)
        key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)',
                     r'cls.predictions.transform.layer_norm.\1', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
                     r'bert.encoder.layers.\1.mlp.fc1.\2', key)
        key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
                     r'bert.encoder.layers.\1.mlp.fc2.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
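
    # At this point (added illustration, assuming a standard Hugging Face BERT checkpoint), a key like
    # 'bert.encoder.layer.0.intermediate.dense.weight' has become 'bert.encoder.layers.0.mlp.fc1.weight',
    # and 'bert.encoder.layer.0.attention.output.LayerNorm.weight' has become
    # 'bert.encoder.layers.0.norm1.weight'.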

    # Attention
    last_layer_subset = getattr(config, 'last_layer_subset', False)
    for d in range(config.num_hidden_layers):
        Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
        Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
        Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight')
        bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
        bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
        bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
        if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
                [Wq, Wk, Wv], dim=0
            )
            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat(
                [bq, bk, bv], dim=0
            )
        else:
            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
                [Wk, Wv], dim=0
            )
            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat(
                [bk, bv], dim=0
            )
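
    # Shape check (added illustration, for a bert-base-sized config with hidden_size=768): Wq, Wk and
    # Wv are each (768, 768), so the fused Wqkv.weight is (2304, 768); for the last layer under
    # last_layer_subset, Wq.weight stays (768, 768) and Wkv.weight is (1536, 768).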

    def key_mapping_attn(key):
        return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
                      r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    def key_mapping_decoder_bias(key):
        return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

    # Word embedding
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
        state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        decoder_weight = state_dict['cls.predictions.decoder.weight']
        state_dict['cls.predictions.decoder.weight'] = F.pad(
            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
        )
        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
        # strongly negative (i.e. the decoder shouldn't predict those indices).
        # TD [2022-05-09]: I don't think it affects the MLPerf training.
        decoder_bias = state_dict['cls.predictions.decoder.bias']
        state_dict['cls.predictions.decoder.bias'] = F.pad(
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )
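        # Worked example (added illustration): bert-base-uncased has vocab_size 30522; with
        # pad_vocab_size_multiple=8, BertModel.__init__ rounds config.vocab_size up to 30528, so 6 extra
        # embedding/decoder rows are zero-padded here and their decoder bias is set to -100.0 so the
        # padded token ids are effectively never predicted.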

    return state_dict
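

# A minimal usage sketch (added illustration, not part of the original module): load Hugging Face
# BERT weights into this implementation and run one forward pass. 'bert-base-uncased' is assumed as
# the checkpoint name; the optional speed flags (use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
# fused_dropout_add_ln, dense_seq_output, last_layer_subset, pad_vocab_size_multiple) are left at
# their defaults so the sketch does not require the fused CUDA extensions or a GPU.
if __name__ == '__main__':
    config = BertConfig.from_pretrained('bert-base-uncased')
    # e.g. config.use_flash_attn = True  # would enable the varlen FlashAttention path (GPU only)
    model = BertForPreTraining.from_pretrained('bert-base-uncased', config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask)
    # Expect (2, 16, vocab_size) MLM logits and (2, 2) next-sentence logits
    print(out.prediction_logits.shape, out.seq_relationship_logits.shape)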