test_gpt.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. import re
  2. import pytest
  3. import torch
  4. from einops import rearrange
  5. from flash_attn.models.gpt import (
  6. GPTLMHeadModel,
  7. remap_state_dict_hf_gpt2,
  8. shard_state_dict_tp,
  9. combine_state_dicts_tp,
  10. )
  11. from flash_attn.utils.generation import InferenceParams
  12. from flash_attn.utils.pretrained import state_dict_from_pretrained
  13. from transformers import GPT2Config, GPT2Tokenizer
  14. from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
  15. @pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
  16. # @pytest.mark.parametrize('model_name', ["gpt2"])
  17. def test_gpt2_state_dict(model_name):
  18. config = GPT2Config.from_pretrained(model_name)
  19. pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
  20. model = GPTLMHeadModel(config)
  21. state_dict = model.state_dict()
  22. assert state_dict.keys() == pretrained_state_dict.keys()
  23. for k in state_dict.keys():
  24. assert state_dict[k].shape == pretrained_state_dict[k].shape
  25. @pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
  26. # @pytest.mark.parametrize('model_name', ["gpt2"])
  27. def test_gpt2_non_optimized(model_name):
  28. """Check that our implementation of GPT2 (without any optimizations enabled) matches the
  29. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  30. forward pass in fp16, when compared to the HF forward pass in fp32.
  31. """
  32. dtype = torch.float16
  33. config = GPT2Config.from_pretrained(model_name)
  34. model = GPTLMHeadModel.from_pretrained(model_name, config)
  35. model = model.cuda().to(dtype=dtype)
  36. model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
  37. model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
  38. model.eval()
  39. model_ref.eval()
  40. model_hf.eval()
  41. torch.manual_seed(0)
  42. batch_size = 4
  43. max_seqlen = 512
  44. seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
  45. input_ids = torch.randint(
  46. 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
  47. )
  48. out = model.transformer(input_ids)
  49. out_hf = model_hf.transformer(input_ids).last_hidden_state
  50. out_ref = model_ref.transformer(input_ids).last_hidden_state
  51. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  52. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  53. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  54. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  55. assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
  56. logits = model(input_ids).logits
  57. logits_hf = model_hf(input_ids).logits
  58. logits_ref = model_ref(input_ids).logits
  59. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  60. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  61. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  62. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  63. assert (logits - logits_ref).abs().max().item() < 3 * (
  64. logits_hf - logits_ref
  65. ).abs().max().item()
  66. @pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
  67. # @pytest.mark.parametrize('model_name', ["gpt2"])
  68. def test_gpt2_optimized(model_name):
  69. """Check that our implementation of GPT2 (with all optimizations enabled) matches the
  70. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  71. forward pass in fp16, when compared to the HF forward pass in fp32.
  72. """
  73. dtype = torch.float16
  74. config = GPT2Config.from_pretrained(model_name)
  75. vocab_size_og = config.vocab_size
  76. config.use_flash_attn = True
  77. config.fused_bias_fc = True
  78. config.fused_mlp = True
  79. config.fused_dropout_add_ln = True
  80. config.residual_in_fp32 = True
  81. config.pad_vocab_size_multiple = 8
  82. model = GPTLMHeadModel.from_pretrained(model_name, config)
  83. model = model.cuda().to(dtype=dtype)
  84. model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
  85. model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
  86. model.eval()
  87. model_ref.eval()
  88. model_hf.eval()
  89. torch.manual_seed(0)
  90. batch_size = 4
  91. max_seqlen = 512
  92. seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
  93. input_ids = torch.randint(
  94. 0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
  95. )
  96. out = model.transformer(input_ids)
  97. out_hf = model_hf.transformer(input_ids).last_hidden_state
  98. out_ref = model_ref.transformer(input_ids).last_hidden_state
  99. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  100. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  101. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  102. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  103. assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
  104. logits = model(input_ids).logits[..., :vocab_size_og]
  105. logits_hf = model_hf(input_ids).logits
  106. logits_ref = model_ref(input_ids).logits
  107. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  108. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  109. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  110. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  111. assert (logits - logits_ref).abs().max().item() < 3 * (
  112. logits_hf - logits_ref
  113. ).abs().max().item()
  114. @pytest.mark.parametrize("fused_ft_kernel", [False, True])
  115. # @pytest.mark.parametrize('fused_ft_kernel', [True])
  116. @pytest.mark.parametrize("optimized", [False, True])
  117. # @pytest.mark.parametrize('optimized', [True])
  118. @pytest.mark.parametrize("rotary", [False, True])
  119. # @pytest.mark.parametrize('rotary', [False])
  120. @pytest.mark.parametrize("model_name", ["gpt2"])
  121. def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
  122. """Check that our implementation of GPT2 generation matches the HF implementation:
  123. the scores in fp16 should be around the same as the HF scores in fp16, when compared to
  124. the HF scores in fp32.
  125. """
  126. dtype = torch.float16
  127. device = "cuda"
  128. rtol, atol = 3e-3, 3e-1
  129. config = GPT2Config.from_pretrained(model_name)
  130. if rotary:
  131. config.n_positions = 0
  132. config.rotary_emb_fraction = 0.5
  133. config.rotary_emb_base = 24000
  134. config.residual_in_fp32 = True
  135. if optimized:
  136. config.use_flash_attn = True
  137. config.fused_bias_fc = True
  138. config.fused_mlp = True
  139. config.fused_dropout_add_ln = True
  140. # if not rotary, we load the weight from HF but ignore the position embeddings.
  141. # The model would be nonsense but it doesn't matter for the test.
  142. model = GPTLMHeadModel.from_pretrained(
  143. model_name, config, strict=not rotary, device=device, dtype=dtype
  144. )
  145. model.eval()
  146. if not rotary:
  147. model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
  148. model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
  149. device=device
  150. )
  151. model_ref.eval()
  152. model_hf.eval()
  153. torch.manual_seed(0)
  154. tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
  155. input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
  156. device=device
  157. )
  158. max_length = 25
  159. # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
  160. # max_length = input_ids.shape[1] + 40
  161. # Slow generation for reference
  162. sequences = []
  163. scores = []
  164. cur_input_ids = input_ids
  165. with torch.inference_mode():
  166. scores.append(model(cur_input_ids).logits[:, -1])
  167. sequences.append(scores[-1].argmax(dim=-1))
  168. for _ in range(input_ids.shape[1] + 1, max_length):
  169. cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
  170. scores.append(model(cur_input_ids).logits[:, -1])
  171. sequences.append(scores[-1].argmax(dim=-1))
  172. sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
  173. scores = tuple(scores)
  174. out = model.generate(
  175. input_ids=input_ids,
  176. max_length=max_length,
  177. fused_ft_kernel=fused_ft_kernel,
  178. return_dict_in_generate=True,
  179. output_scores=True,
  180. enable_timing=True,
  181. )
  182. print(out.sequences)
  183. print(tokenizer.batch_decode(out.sequences.tolist()))
  184. if fused_ft_kernel or getattr(config, "use_flash_attn", False):
  185. out_cg = model.generate(
  186. input_ids=input_ids,
  187. max_length=max_length,
  188. fused_ft_kernel=fused_ft_kernel,
  189. cg=True,
  190. return_dict_in_generate=True,
  191. output_scores=True,
  192. enable_timing=True,
  193. )
  194. print(out_cg.sequences)
  195. assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
  196. if not rotary:
  197. out_hf = model_hf.generate(
  198. input_ids=input_ids,
  199. max_length=max_length,
  200. return_dict_in_generate=True,
  201. output_scores=True,
  202. )
  203. out_ref = model_ref.generate(
  204. input_ids=input_ids,
  205. max_length=max_length,
  206. return_dict_in_generate=True,
  207. output_scores=True,
  208. )
  209. print(
  210. f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
  211. )
  212. print(
  213. f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
  214. )
  215. print(
  216. f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
  217. )
  218. print(
  219. f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
  220. )
  221. print(tokenizer.batch_decode(out_ref.sequences.tolist()))
  222. assert torch.all(out.sequences == sequences)
  223. assert torch.allclose(
  224. torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
  225. )
  226. if not rotary:
  227. assert torch.all(out.sequences == out_ref.sequences)
  228. assert torch.all(out.sequences == out_hf.sequences)
  229. assert (
  230. torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
  231. ).abs().max().item() < 3 * (
  232. torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
  233. ).abs().max().item()
  234. def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
  235. out = model.generate(
  236. input_ids=input_ids,
  237. max_length=max_length,
  238. teacher_outputs=teacher_outputs,
  239. return_dict_in_generate=True,
  240. output_scores=True,
  241. enable_timing=True,
  242. **kwargs,
  243. )
  244. return torch.stack(out.scores, dim=1)
  245. @pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
  246. # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
  247. @pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
  248. # @pytest.mark.parametrize('rotary', [None])
  249. @pytest.mark.parametrize("fused_ft_kernel", [False, True])
  250. # @pytest.mark.parametrize("fused_ft_kernel", [False])
  251. @pytest.mark.parametrize("model_name", ["gpt2"])
  252. def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
  253. """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
  254. dtype = torch.float16
  255. device = "cuda"
  256. rtol, atol = 3e-3, 3e-1
  257. config = GPT2Config.from_pretrained(model_name)
  258. config.n_positions = 16 * 1024
  259. assert seqlen <= maxlen <= config.n_positions
  260. if rotary is not None:
  261. config.n_positions = 0
  262. config.rotary_emb_dim = 32
  263. config.rotary_emb_interleaved = rotary == "interleaved"
  264. config.residual_in_fp32 = True
  265. config.use_flash_attn = True
  266. config.fused_bias_fc = True
  267. config.fused_mlp = True
  268. config.fused_dropout_add_ln = True
  269. model = GPTLMHeadModel(config, device=device, dtype=dtype)
  270. model.eval()
  271. torch.manual_seed(0)
  272. batch_size = 1
  273. input_ids = torch.randint(
  274. 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
  275. )
  276. teacher_outputs = torch.randint(
  277. 0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
  278. )
  279. logits = get_logits(
  280. model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
  281. )
  282. logits_cg = get_logits(
  283. model,
  284. input_ids,
  285. maxlen,
  286. teacher_outputs=teacher_outputs,
  287. fused_ft_kernel=fused_ft_kernel,
  288. cg=True,
  289. )
  290. assert torch.equal(logits, logits_cg)
  291. # Try increasing batch size and seqlen, then decrease them to see if it's still correct
  292. batch_size = 3
  293. maxlen += 30
  294. input_ids = torch.randint(
  295. 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
  296. )
  297. teacher_outputs = torch.randint(
  298. 0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
  299. )
  300. logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
  301. logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
  302. assert torch.equal(logits, logits_cg)
  303. batch_size = 2
  304. maxlen -= 35
  305. input_ids = torch.randint(
  306. 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
  307. )
  308. teacher_outputs = torch.randint(
  309. 0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
  310. )
  311. logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
  312. logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
  313. assert torch.equal(logits, logits_cg)
  314. @pytest.mark.parametrize("optimized", [False, True])
  315. # @pytest.mark.parametrize("optimized", [False])
  316. @pytest.mark.parametrize("model_name", ["gpt2"])
  317. def test_gpt2_multiple_token_generation(model_name, optimized):
  318. """Generation when we pass in multiple tokens at a time, not just one."""
  319. dtype = torch.float16
  320. device = "cuda"
  321. rtol, atol = 3e-3, 3e-1
  322. config = GPT2Config.from_pretrained(model_name)
  323. config.residual_in_fp32 = True
  324. if optimized:
  325. config.use_flash_attn = True
  326. config.fused_bias_fc = True
  327. config.fused_mlp = True
  328. config.fused_dropout_add_ln = True
  329. # fused_ft_kernel currently doesn't work with multiple tokens at a time
  330. model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
  331. model.eval()
  332. torch.manual_seed(0)
  333. input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)
  334. # Reference logits
  335. logits_ref = model(input_ids).logits
  336. # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
  337. inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
  338. logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
  339. inference_params.sequence_len_offset += 10
  340. position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
  341. logits_1014 = model(
  342. input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
  343. ).logits
  344. inference_params.sequence_len_offset += 4
  345. position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
  346. logits_1420 = model(
  347. input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
  348. ).logits
  349. logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
  350. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  351. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  352. assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
  353. @pytest.mark.parametrize("fused_ft_kernel, cg", [(False, False), (True, False), (True, True)])
  354. # @pytest.mark.parametrize("fused_ft_kernel, cg", [(True, True)])
  355. # @pytest.mark.parametrize("optimized", [False, True])
  356. @pytest.mark.parametrize("optimized", [True])
  357. # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
  358. @pytest.mark.parametrize("model_name", ["gpt2-xl"])
  359. def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
  360. dtype = torch.float16
  361. device = "cuda"
  362. rtol, atol = 3e-3, 3e-1
  363. config = GPT2Config.from_pretrained(model_name)
  364. config.residual_in_fp32 = True
  365. if optimized:
  366. config.use_flash_attn = True
  367. config.fused_bias_fc = True
  368. config.fused_mlp = True
  369. config.fused_dropout_add_ln = True
  370. config_draft = GPT2Config.from_pretrained("gpt2")
  371. config_draft.residual_in_fp32 = True
  372. if optimized:
  373. config_draft.use_flash_attn = True
  374. config_draft.fused_bias_fc = True
  375. config_draft.fused_mlp = True
  376. config_draft.fused_dropout_add_ln = True
  377. model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
  378. model.eval()
  379. model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype)
  380. model_draft.eval()
  381. torch.manual_seed(0)
  382. tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
  383. input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
  384. device=device
  385. )
  386. max_length = 100
  387. from flash_attn.utils.generation import decode_speculative
  388. torch.manual_seed(42)
  389. out = decode_speculative(
  390. input_ids,
  391. model,
  392. model_draft,
  393. max_length=max_length,
  394. top_k=5,
  395. fused_ft_kernel=fused_ft_kernel,
  396. cg=cg,
  397. speculative_lookahead=4,
  398. enable_timing=True,
  399. )
  400. print(tokenizer.batch_decode(out.sequences))
  401. out_og = model.generate(
  402. input_ids,
  403. max_length=max_length,
  404. top_k=5,
  405. fused_ft_kernel=fused_ft_kernel,
  406. cg=False,
  407. enable_timing=True,
  408. return_dict_in_generate=True,
  409. )
  410. print(tokenizer.batch_decode(out_og.sequences))
  411. @pytest.mark.parametrize(
  412. "n_heads_q_kv",
  413. [
  414. (8, 8), # Regular attention
  415. (8, 4), # GQA
  416. (8, 2), # MQA
  417. ],
  418. )
  419. def test_gpt2_shard_unshard(n_heads_q_kv):
  420. world_size = 2
  421. config = GPT2Config.from_pretrained("gpt2")
  422. config.vocab_size = 1024
  423. config.n_head, config.n_head_kv = n_heads_q_kv
  424. model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
  425. state_dict = model.state_dict()
  426. shards = [
  427. # NOTE: Shallow copy as `state_dict` is modified in-place
  428. shard_state_dict_tp(dict(state_dict), config, world_size, rank)
  429. for rank in range(world_size)
  430. ]
  431. state_dict2 = combine_state_dicts_tp(shards, config)
  432. assert state_dict2.keys() == state_dict.keys()
  433. for k in state_dict.keys():
  434. ref = state_dict[k]
  435. new = state_dict[k]
  436. assert torch.allclose(ref, new, atol=0.0, rtol=0.0)