# test_bert.py
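# Tests for the flash_attn BERT implementation: state-dict remapping round trips against
# HuggingFace checkpoints, and fp16 outputs are compared against the HF fp32 reference.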

import re
from collections import OrderedDict

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF

from flash_attn.models.bert import (
    BertForPreTraining,
    BertModel,
    inv_remap_state_dict,
    remap_state_dict,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    config = BertConfig.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
    model = BertForPreTraining(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)
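
    # Older BERT checkpoints store the LayerNorm parameters as "gamma" / "beta",
    # while the current HF module expects "weight" / "bias"; rename them before loading.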
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    pretrained_state_dict = OrderedDict(
        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
    )
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf.cuda().to(dtype=dtype)
    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the error of our fp16 forward pass relative to the HF fp32 forward pass
    should be comparable to the error of the HF fp16 forward pass relative to the same fp32 reference.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)
    model.eval()
    model_ref.eval()
    model_hf.eval()
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
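    # Random sequence lengths in [max_seqlen // 2, max_seqlen]; the attention mask is True
    # for real tokens and False for padding positions.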
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
    print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
    assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the error of our fp16 forward pass relative to the HF fp32 forward pass
    should be comparable to the error of the HF fp16 forward pass relative to the same fp32 reference.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is nn.GELU(approximate="tanh").
    # HuggingFace calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)
    model.eval()
    model_ref.eval()
    model_hf.eval()
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    sequence_output_hf[~attention_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0
    print(
        f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
    )
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()

    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    # Need to zero out the padded tokens in the sequence before comparison.
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref[~attention_mask, :] = 0.0
    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()


@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the error of our fp16 forward pass relative to the HF fp32 forward pass
    should be comparable to the error of the HF fp16 forward pass relative to the same fp32 reference.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is nn.GELU(approximate="tanh").
    # HuggingFace calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
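    # dense_seq_output: compute the MLM head only on the masked tokens (labels > 0);
    # last_layer_subset: the final attention layer restricts its queries to those tokens;
    # use_xentropy: use the fused cross-entropy loss kernel.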
    config.dense_seq_output = True
    config.last_layer_subset = last_layer_subset
    config.use_xentropy = True
    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)
    model.eval()
    model_ref.eval()
    model_hf.eval()
    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    if has_key_padding_mask:
        attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    else:
        attention_mask = None
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    labels = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
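    # Zero out labels at padded positions, then keep only ~15% of positions as MLM targets;
    # in this test a label of 0 means "not a prediction target".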
    if attention_mask is not None:
        labels[~attention_mask] = 0
    labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")
    out = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    out_hf = model_hf(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_hf, seq_relationship_scores_hf = (
        out_hf.prediction_logits,
        out_hf.seq_relationship_logits,
    )
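    # With dense_seq_output, our model returns predictions only for the masked tokens,
    # so gather the same subset from the HF / reference outputs before comparing.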
    prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
    out_ref = model_ref(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        next_sentence_label=next_sequence_label,
    )
    prediction_scores_ref, seq_relationship_scores_ref = (
        out_ref.prediction_logits,
        out_ref.seq_relationship_logits,
    )
    prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]
    print(
        f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref
    ).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref
    ).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_inv_remap_state_dict(model_name: str):
    """
    Verify that we can convert an HF BERT model to the flash_attn format and back.
    """
    state_dict = state_dict_from_pretrained(model_name)
    config = BertConfig.from_pretrained(model_name)
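    # Round trip: HF layout -> flash_attn layout -> HF layout should be lossless.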
    flash_state_dict = remap_state_dict(state_dict, config)
    recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)
    assert set(state_dict.keys()) == set(recovered_state_dict.keys())
    for k in state_dict.keys():
        assert state_dict[k].shape == recovered_state_dict[k].shape
        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)