1
0

bert.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
  1. # Copyright (c) 2022, Tri Dao.
  2. # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
  3. # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
  4. # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
  5. # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
  6. import logging
  7. import re
  8. from collections import OrderedDict
  9. from collections.abc import Sequence
  10. from functools import partial
  11. from typing import Any, Mapping
  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. from einops import rearrange
  16. from transformers import BertConfig, PretrainedConfig
  17. from transformers.models.bert.modeling_bert import (
  18. BaseModelOutputWithPoolingAndCrossAttentions,
  19. BertForPreTrainingOutput,
  20. )
  21. from flash_attn.bert_padding import (
  22. index_first_axis,
  23. index_first_axis_residual,
  24. pad_input,
  25. unpad_input,
  26. )
  27. from flash_attn.modules.block import Block
  28. from flash_attn.modules.embedding import BertEmbeddings
  29. from flash_attn.modules.mha import MHA
  30. from flash_attn.modules.mlp import FusedMLP, Mlp
  31. from flash_attn.utils.pretrained import state_dict_from_pretrained
  32. try:
  33. from flash_attn.ops.fused_dense import FusedDense
  34. except ImportError:
  35. FusedDense = None
  36. try:
  37. from flash_attn.ops.triton.layer_norm import layer_norm_fn
  38. except ImportError:
  39. layer_norm_fn = None
  40. try:
  41. from flash_attn.losses.cross_entropy import CrossEntropyLoss
  42. except ImportError:
  43. CrossEntropyLoss = None
  44. logger = logging.getLogger(__name__)
  45. def create_mixer_cls(config, cross_attn=False, return_residual=False):
  46. use_flash_attn = getattr(config, "use_flash_attn", False)
  47. fused_bias_fc = getattr(config, "fused_bias_fc", False)
  48. rotary_kwargs = {}
  49. if config.position_embedding_type == "rotary":
  50. rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
  51. rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
  52. rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
  53. rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
  54. mixer_cls = partial(
  55. MHA,
  56. num_heads=config.num_attention_heads,
  57. cross_attn=cross_attn,
  58. dropout=config.attention_probs_dropout_prob,
  59. causal=False,
  60. fused_bias_fc=fused_bias_fc,
  61. use_flash_attn=use_flash_attn,
  62. return_residual=return_residual,
  63. **rotary_kwargs,
  64. )
  65. return mixer_cls
  66. def create_mlp_cls(config, layer_idx=None, return_residual=False):
  67. inner_dim = config.intermediate_size
  68. fused_mlp = getattr(config, "fused_mlp", False)
  69. if fused_mlp:
  70. assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
  71. "fused_mlp only " "supports approximate gelu"
  72. )
  73. if not fused_mlp:
  74. approximate = (
  75. "tanh"
  76. if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
  77. else "none"
  78. )
  79. mlp_cls = partial(
  80. Mlp,
  81. hidden_features=inner_dim,
  82. activation=partial(F.gelu, approximate=approximate),
  83. return_residual=return_residual,
  84. )
  85. else:
  86. if FusedMLP is None:
  87. raise ImportError("fused_dense is not installed")
  88. mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
  89. # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
  90. if isinstance(mlp_checkpoint_lvl, Sequence):
  91. assert layer_idx is not None
  92. mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
  93. mlp_cls = partial(
  94. FusedMLP,
  95. hidden_features=inner_dim,
  96. checkpoint_lvl=mlp_checkpoint_lvl,
  97. return_residual=return_residual,
  98. )
  99. return mlp_cls
  100. def create_block(config, layer_idx=None):
  101. last_layer_subset = getattr(config, "last_layer_subset", False)
  102. cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
  103. # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
  104. # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
  105. # one layer) so we just choose not to return residual in this case.
  106. return_residual = not cross_attn
  107. mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
  108. mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
  109. norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
  110. block = Block(
  111. config.hidden_size,
  112. mixer_cls,
  113. mlp_cls,
  114. norm_cls=norm_cls,
  115. prenorm=False,
  116. resid_dropout1=config.hidden_dropout_prob,
  117. resid_dropout2=config.hidden_dropout_prob,
  118. fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
  119. return_residual=return_residual,
  120. )
  121. return block
  122. # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
  123. def _init_weights(module, initializer_range=0.02):
  124. if isinstance(module, nn.Linear):
  125. nn.init.normal_(module.weight, std=initializer_range)
  126. if module.bias is not None:
  127. nn.init.zeros_(module.bias)
  128. elif isinstance(module, nn.Embedding):
  129. nn.init.normal_(module.weight, std=initializer_range)
  130. if module.padding_idx is not None:
  131. nn.init.zeros_(module.weight[module.padding_idx])
  132. class BertEncoder(nn.Module):
  133. def __init__(self, config: BertConfig):
  134. super().__init__()
  135. self.use_flash_attn = getattr(config, "use_flash_attn", False)
  136. self.layers = nn.ModuleList(
  137. [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
  138. )
  139. def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
  140. """If subset_mask is not None, we only want output for the subset of the sequence.
  141. This means that we only compute the last layer output for these tokens.
  142. subset_mask: (batch, seqlen), dtype=torch.bool
  143. """
  144. if key_padding_mask is None or not self.use_flash_attn:
  145. mixer_kwargs = (
  146. {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
  147. )
  148. for layer in self.layers:
  149. hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
  150. if subset_mask is not None:
  151. hidden_states = hidden_states[subset_mask]
  152. else:
  153. batch, seqlen = hidden_states.shape[:2]
  154. hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
  155. hidden_states, key_padding_mask
  156. )
  157. mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
  158. if subset_mask is None:
  159. for layer in self.layers:
  160. hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
  161. hidden_states = pad_input(hidden_states, indices, batch, seqlen)
  162. else:
  163. for layer in self.layers[:-1]:
  164. hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
  165. if key_padding_mask is not None:
  166. subset_idx = torch.nonzero(
  167. subset_mask[key_padding_mask], as_tuple=False
  168. ).flatten()
  169. subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
  170. subset_cu_seqlens = F.pad(
  171. torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
  172. )
  173. else:
  174. subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
  175. subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
  176. subset_cu_seqlens = F.pad(
  177. torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
  178. )
  179. hidden_states_subset, hidden_states = index_first_axis_residual(
  180. hidden_states, subset_idx
  181. )
  182. # It's ok to set max_seqlen_q to be much larger
  183. mixer_kwargs = {
  184. "x_kv": hidden_states,
  185. "cu_seqlens": subset_cu_seqlens,
  186. "max_seqlen": max_seqlen_in_batch,
  187. "cu_seqlens_k": cu_seqlens,
  188. "max_seqlen_k": max_seqlen_in_batch,
  189. }
  190. hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
  191. return hidden_states
  192. class BertPooler(nn.Module):
  193. def __init__(self, config):
  194. super().__init__()
  195. fused_bias_fc = getattr(config, "fused_bias_fc", False)
  196. if fused_bias_fc and FusedDense is None:
  197. raise ImportError("fused_dense is not installed")
  198. linear_cls = nn.Linear if not fused_bias_fc else FusedDense
  199. self.dense = linear_cls(config.hidden_size, config.hidden_size)
  200. self.activation = nn.Tanh()
  201. def forward(self, hidden_states, pool=True):
  202. # We "pool" the model by simply taking the hidden state corresponding
  203. # to the first token.
  204. first_token_tensor = hidden_states[:, 0] if pool else hidden_states
  205. pooled_output = self.dense(first_token_tensor)
  206. pooled_output = self.activation(pooled_output)
  207. return pooled_output
  208. class BertPredictionHeadTransform(nn.Module):
  209. def __init__(self, config):
  210. super().__init__()
  211. fused_bias_fc = getattr(config, "fused_bias_fc", False)
  212. if fused_bias_fc and FusedDense is None:
  213. raise ImportError("fused_dense is not installed")
  214. self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
  215. if self.fused_dropout_add_ln and layer_norm_fn is None:
  216. raise ImportError("Triton is not installed")
  217. linear_cls = nn.Linear if not fused_bias_fc else FusedDense
  218. self.dense = linear_cls(config.hidden_size, config.hidden_size)
  219. approximate = (
  220. "tanh"
  221. if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
  222. else "none"
  223. )
  224. self.transform_act_fn = nn.GELU(approximate=approximate)
  225. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  226. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  227. hidden_states = self.dense(hidden_states)
  228. hidden_states = self.transform_act_fn(hidden_states)
  229. if not self.fused_dropout_add_ln:
  230. hidden_states = self.layer_norm(hidden_states)
  231. else:
  232. hidden_states = layer_norm_fn(
  233. hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
  234. )
  235. return hidden_states
  236. class BertLMPredictionHead(nn.Module):
  237. def __init__(self, config):
  238. super().__init__()
  239. fused_bias_fc = getattr(config, "fused_bias_fc", False)
  240. if fused_bias_fc and FusedDense is None:
  241. raise ImportError("fused_dense is not installed")
  242. linear_cls = nn.Linear if not fused_bias_fc else FusedDense
  243. self.transform = BertPredictionHeadTransform(config)
  244. # The output weights are the same as the input embeddings, but there is
  245. # an output-only bias for each token.
  246. self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
  247. def forward(self, hidden_states):
  248. hidden_states = self.transform(hidden_states)
  249. hidden_states = self.decoder(hidden_states)
  250. return hidden_states
  251. class BertPreTrainingHeads(nn.Module):
  252. def __init__(self, config):
  253. super().__init__()
  254. self.predictions = BertLMPredictionHead(config)
  255. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  256. def forward(self, sequence_output, pooled_output):
  257. prediction_scores = self.predictions(sequence_output)
  258. seq_relationship_score = self.seq_relationship(pooled_output)
  259. return prediction_scores, seq_relationship_score
  260. class BertPreTrainedModel(nn.Module):
  261. """An abstract class to handle weights initialization and
  262. a simple interface for dowloading and loading pretrained models.
  263. """
  264. def __init__(self, config, *inputs, **kwargs):
  265. super().__init__()
  266. if not isinstance(config, BertConfig):
  267. raise ValueError(
  268. "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
  269. "To create a model from a Google pretrained model use "
  270. "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
  271. self.__class__.__name__, self.__class__.__name__
  272. )
  273. )
  274. self.config = config
  275. @classmethod
  276. def from_pretrained(cls, model_name, config, *inputs, **kwargs):
  277. """
  278. Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
  279. Download and cache the pre-trained model file if needed.
  280. Params:
  281. pretrained_model_name_or_path: either:
  282. - a path or url to a pretrained model archive containing:
  283. . `bert_config.json` a configuration file for the model
  284. . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
  285. - a path or url to a pretrained model archive containing:
  286. . `bert_config.json` a configuration file for the model
  287. . `model.chkpt` a TensorFlow checkpoint
  288. *inputs, **kwargs: additional input for the specific Bert class
  289. (ex: num_labels for BertForSequenceClassification)
  290. """
  291. # Instantiate model.
  292. model = cls(config, *inputs, **kwargs)
  293. load_return = model.load_state_dict(
  294. remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
  295. )
  296. logger.info(load_return)
  297. return model
  298. class BertModel(BertPreTrainedModel):
  299. def __init__(self, config: BertConfig, add_pooling_layer=True):
  300. super().__init__(config)
  301. self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
  302. if config.vocab_size % self.pad_vocab_size_multiple != 0:
  303. config.vocab_size += self.pad_vocab_size_multiple - (
  304. config.vocab_size % self.pad_vocab_size_multiple
  305. )
  306. self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
  307. if self.fused_dropout_add_ln and layer_norm_fn is None:
  308. raise ImportError("Triton is not installed")
  309. assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
  310. self.embeddings = BertEmbeddings(
  311. config.hidden_size,
  312. config.vocab_size,
  313. config.max_position_embeddings,
  314. config.type_vocab_size,
  315. padding_idx=config.pad_token_id,
  316. )
  317. self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
  318. self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  319. self.encoder = BertEncoder(config)
  320. self.pooler = BertPooler(config) if add_pooling_layer else None
  321. self.apply(partial(_init_weights, initializer_range=config.initializer_range))
  322. def forward(
  323. self,
  324. input_ids,
  325. position_ids=None,
  326. token_type_ids=None,
  327. attention_mask=None,
  328. masked_tokens_mask=None,
  329. ):
  330. """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
  331. we only want the output for the masked tokens. This means that we only compute the last
  332. layer output for these tokens.
  333. masked_tokens_mask: (batch, seqlen), dtype=torch.bool
  334. """
  335. hidden_states = self.embeddings(
  336. input_ids, position_ids=position_ids, token_type_ids=token_type_ids
  337. )
  338. # TD [2022-12:18]: Don't need to force residual in fp32
  339. # BERT puts embedding LayerNorm before embedding dropout.
  340. if not self.fused_dropout_add_ln:
  341. hidden_states = self.emb_ln(hidden_states)
  342. else:
  343. hidden_states = layer_norm_fn(
  344. hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
  345. )
  346. hidden_states = self.emb_drop(hidden_states)
  347. if masked_tokens_mask is not None:
  348. batch_size, seqlen = input_ids.shape[:2]
  349. # We also need the first column for the CLS token
  350. first_col_mask = torch.zeros(
  351. batch_size, seqlen, dtype=torch.bool, device=input_ids.device
  352. )
  353. first_col_mask[:, 0] = True
  354. subset_mask = masked_tokens_mask | first_col_mask
  355. else:
  356. subset_mask = None
  357. sequence_output = self.encoder(
  358. hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
  359. )
  360. if masked_tokens_mask is None:
  361. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  362. else:
  363. # TD [2022-03-01]: the indexing here is very tricky.
  364. if attention_mask is not None:
  365. subset_idx = subset_mask[attention_mask]
  366. pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
  367. sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
  368. else:
  369. pool_input = sequence_output[first_col_mask[subset_mask]]
  370. sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
  371. pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
  372. return BaseModelOutputWithPoolingAndCrossAttentions(
  373. last_hidden_state=sequence_output,
  374. pooler_output=pooled_output,
  375. )
  376. class BertForPreTraining(BertPreTrainedModel):
  377. def __init__(self, config: BertConfig):
  378. super().__init__(config)
  379. # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
  380. # (around 15%) to the classifier heads.
  381. self.dense_seq_output = getattr(config, "dense_seq_output", False)
  382. # If last_layer_subset, we only need the compute the last layer for a subset of tokens
  383. # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
  384. self.last_layer_subset = getattr(config, "last_layer_subset", False)
  385. if self.last_layer_subset:
  386. assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
  387. use_xentropy = getattr(config, "use_xentropy", False)
  388. if use_xentropy and CrossEntropyLoss is None:
  389. raise ImportError("xentropy_cuda is not installed")
  390. loss_cls = (
  391. nn.CrossEntropyLoss
  392. if not use_xentropy
  393. else partial(CrossEntropyLoss, inplace_backward=True)
  394. )
  395. self.bert = BertModel(config)
  396. self.cls = BertPreTrainingHeads(config)
  397. self.mlm_loss = loss_cls(ignore_index=0)
  398. self.nsp_loss = loss_cls(ignore_index=-1)
  399. # Initialize weights and apply final processing
  400. self.apply(partial(_init_weights, initializer_range=config.initializer_range))
  401. self.tie_weights()
  402. def tie_weights(self):
  403. self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
  404. def forward(
  405. self,
  406. input_ids,
  407. position_ids=None,
  408. token_type_ids=None,
  409. attention_mask=None,
  410. labels=None,
  411. next_sentence_label=None,
  412. ):
  413. """
  414. If labels are provided, they must be 0 for masked out tokens (as specified in the attention
  415. mask).
  416. Outputs:
  417. if `labels` and `next_sentence_label` are not `None`:
  418. Outputs the total_loss which is the sum of the masked language modeling loss and the next
  419. sentence classification loss.
  420. if `labels` or `next_sentence_label` is `None`:
  421. Outputs a tuple comprising
  422. - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
  423. - the next sentence classification logits of shape [batch_size, 2].
  424. """
  425. masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
  426. outputs = self.bert(
  427. input_ids,
  428. position_ids=position_ids,
  429. token_type_ids=token_type_ids,
  430. attention_mask=attention_mask.bool() if attention_mask is not None else None,
  431. masked_tokens_mask=masked_tokens_mask,
  432. )
  433. sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
  434. if self.dense_seq_output and labels is not None:
  435. masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
  436. if not self.last_layer_subset:
  437. sequence_output = index_first_axis(
  438. rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
  439. )
  440. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  441. total_loss = None
  442. if labels is not None and next_sentence_label is not None:
  443. if (
  444. self.dense_seq_output and labels is not None
  445. ): # prediction_scores are already flattened
  446. masked_lm_loss = self.mlm_loss(
  447. prediction_scores, labels.flatten()[masked_token_idx]
  448. )
  449. else:
  450. masked_lm_loss = self.mlm_loss(
  451. rearrange(prediction_scores, "... v -> (...) v"),
  452. rearrange(labels, "... -> (...)"),
  453. )
  454. next_sentence_loss = self.nsp_loss(
  455. rearrange(seq_relationship_score, "... t -> (...) t"),
  456. rearrange(next_sentence_label, "... -> (...)"),
  457. )
  458. total_loss = masked_lm_loss.float() + next_sentence_loss.float()
  459. return BertForPreTrainingOutput(
  460. loss=total_loss,
  461. prediction_logits=prediction_scores,
  462. seq_relationship_logits=seq_relationship_score,
  463. )
  464. def remap_state_dict(state_dict, config: PretrainedConfig):
  465. """
  466. Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
  467. """
  468. # LayerNorm
  469. def key_mapping_ln_gamma_beta(key):
  470. key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
  471. key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
  472. return key
  473. state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
  474. # Layers
  475. def key_mapping_layers(key):
  476. return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
  477. state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
  478. # LayerNorm
  479. def key_mapping_ln(key):
  480. key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
  481. key = re.sub(
  482. r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
  483. r"bert.encoder.layers.\1.norm1.\2",
  484. key,
  485. )
  486. key = re.sub(
  487. r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
  488. r"bert.encoder.layers.\1.norm2.\2",
  489. key,
  490. )
  491. key = re.sub(
  492. r"^cls.predictions.transform.LayerNorm.(weight|bias)",
  493. r"cls.predictions.transform.layer_norm.\1",
  494. key,
  495. )
  496. return key
  497. state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
  498. # MLP
  499. def key_mapping_mlp(key):
  500. key = re.sub(
  501. r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
  502. r"bert.encoder.layers.\1.mlp.fc1.\2",
  503. key,
  504. )
  505. key = re.sub(
  506. r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
  507. r"bert.encoder.layers.\1.mlp.fc2.\2",
  508. key,
  509. )
  510. return key
  511. state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
  512. # Attention
  513. last_layer_subset = getattr(config, "last_layer_subset", False)
  514. for d in range(config.num_hidden_layers):
  515. Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
  516. Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
  517. Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
  518. bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
  519. bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
  520. bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
  521. if not (last_layer_subset and d == config.num_hidden_layers - 1):
  522. state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
  523. [Wq, Wk, Wv], dim=0
  524. )
  525. state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
  526. else:
  527. state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
  528. state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
  529. state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
  530. state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
  531. def key_mapping_attn(key):
  532. return re.sub(
  533. r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
  534. r"bert.encoder.layers.\1.mixer.out_proj.\2",
  535. key,
  536. )
  537. state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
  538. def key_mapping_decoder_bias(key):
  539. return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
  540. state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
  541. # Word embedding
  542. pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
  543. if pad_vocab_size_multiple > 1:
  544. word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
  545. state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
  546. word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
  547. )
  548. decoder_weight = state_dict["cls.predictions.decoder.weight"]
  549. state_dict["cls.predictions.decoder.weight"] = F.pad(
  550. decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
  551. )
  552. # If the vocab was padded, we want to set the decoder bias for those padded indices to be
  553. # strongly negative (i.e. the decoder shouldn't predict those indices).
  554. # TD [2022-05-09]: I don't think it affects the MLPerf training.
  555. decoder_bias = state_dict["cls.predictions.decoder.bias"]
  556. state_dict["cls.predictions.decoder.bias"] = F.pad(
  557. decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
  558. )
  559. return state_dict
  560. def inv_remap_state_dict(state_dict, config: PretrainedConfig):
  561. """
  562. Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
  563. This function is meant to be the inverse of remap_state_dict.
  564. """
  565. # Word embedding
  566. pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
  567. if pad_vocab_size_multiple > 1:
  568. word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
  569. decoder_weight = state_dict["cls.predictions.decoder.weight"]
  570. decoder_bias = state_dict["cls.predictions.decoder.bias"]
  571. # unpad embeddings
  572. state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
  573. : config.orig_vocab_size, :
  574. ]
  575. state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
  576. state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
  577. for d in range(config.num_hidden_layers):
  578. last_layer_subset = getattr(config, "last_layer_subset", False)
  579. if not last_layer_subset or d != (config.num_hidden_layers - 1):
  580. Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
  581. Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
  582. state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
  583. : Wqkv_weights.shape[0] // 3, :
  584. ]
  585. state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
  586. Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
  587. ]
  588. state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
  589. 2 * Wqkv_weights.shape[0] // 3 :, :
  590. ]
  591. state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
  592. : Wqkv_biases.shape[0] // 3
  593. ]
  594. state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
  595. Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
  596. ]
  597. state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
  598. 2 * Wqkv_biases.shape[0] // 3 :
  599. ]
  600. else:
  601. Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
  602. Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
  603. Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
  604. Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
  605. state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
  606. state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
  607. : Wkv_weights.shape[0] // 2, :
  608. ]
  609. state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
  610. Wkv_weights.shape[0] // 2 :, :
  611. ]
  612. state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
  613. state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
  614. : Wkv_biases.shape[0] // 2
  615. ]
  616. state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
  617. Wkv_biases.shape[0] // 2 :
  618. ]
  619. def inv_key_mapping_ln(key):
  620. key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
  621. key = re.sub(
  622. r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
  623. r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
  624. key,
  625. )
  626. key = re.sub(
  627. r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
  628. r"bert.encoder.layers.\1.output.LayerNorm.\2",
  629. key,
  630. )
  631. key = re.sub(
  632. r"cls.predictions.transform.layer_norm.(weight|bias)",
  633. r"cls.predictions.transform.LayerNorm.\1",
  634. key,
  635. )
  636. return key
  637. def inv_key_mapping_ln_gamma_beta(key):
  638. key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
  639. key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
  640. return key
  641. def inv_key_mapping_layers(key):
  642. return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
  643. def inv_key_mapping_mlp(key):
  644. key = re.sub(
  645. r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
  646. r"bert.encoder.layer.\1.intermediate.dense.\2",
  647. key,
  648. )
  649. key = re.sub(
  650. r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
  651. r"bert.encoder.layer.\1.output.dense.\2",
  652. key,
  653. )
  654. return key
  655. def inv_key_mapping_attn(key):
  656. return re.sub(
  657. r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
  658. r"bert.encoder.layer.\1.attention.output.dense.\2",
  659. key,
  660. )
  661. def inv_key_mapping_decoder_bias(key):
  662. return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
  663. state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
  664. state_dict = OrderedDict(
  665. (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
  666. )
  667. state_dict = OrderedDict(
  668. (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
  669. )
  670. state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
  671. state_dict = OrderedDict(
  672. (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
  673. )
  674. state_dict = OrderedDict(
  675. (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
  676. )
  677. return state_dict