test_gptj.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright (c) 2023, Tri Dao.
  2. import time
  3. import pytest
  4. import torch
  5. from flash_attn.models.gpt import GPTLMHeadModel
  6. from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
  7. from flash_attn.utils.generation import update_graph_cache
  8. from flash_attn.utils.pretrained import state_dict_from_pretrained
  9. from transformers import AutoTokenizer, GPTJConfig
  10. from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
  11. @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
  12. def test_gptj_state_dict(model_name):
  13. config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
  14. pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
  15. model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
  16. state_dict = model.state_dict()
  17. assert state_dict.keys() == pretrained_state_dict.keys()
  18. for k in state_dict.keys():
  19. assert state_dict[k].shape == pretrained_state_dict[k].shape
  20. @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B", "togethercomputer/GPT-JT-6B-v1"])
  21. def test_gptj_optimized(model_name):
  22. """Check that our implementation of GPT-J (with all optimizations enabled) matches the
  23. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  24. forward pass in fp16, when compared to the HF forward pass in fp32.
  25. """
  26. dtype = torch.float16
  27. device = "cuda"
  28. config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
  29. config.use_flash_attn = True # FlashAttention-2 supports headdim 256
  30. config.fused_bias_fc = True
  31. config.fused_mlp = True
  32. config.fused_dropout_add_ln = True
  33. config.residual_in_fp32 = True
  34. model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
  35. model.eval()
  36. torch.manual_seed(0)
  37. batch_size = 2
  38. max_seqlen = 256
  39. input_ids = torch.randint(
  40. 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
  41. )
  42. with torch.no_grad():
  43. out = model.transformer(input_ids)
  44. logits = model(input_ids).logits
  45. del model
  46. # Without device_map, the model is loaded on the CPU, which is very slow
  47. model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
  48. model_ref.eval()
  49. with torch.no_grad():
  50. out_ref = model_ref.transformer(input_ids).last_hidden_state
  51. logits_ref = model_ref(input_ids).logits
  52. del model_ref
  53. model_hf = GPTJForCausalLM.from_pretrained(
  54. model_name, torch_dtype=dtype, device_map={"": device}
  55. )
  56. model_hf.eval()
  57. out_hf = model_hf.transformer(input_ids).last_hidden_state
  58. logits_hf = model_hf(input_ids).logits
  59. del model_hf
  60. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  61. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  62. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  63. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  64. assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
  65. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  66. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  67. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  68. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  69. assert (logits - logits_ref).abs().max().item() < 3 * (
  70. logits_hf - logits_ref
  71. ).abs().max().item()
  72. @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
  73. def test_gptj_generation(model_name):
  74. """Check that our implementation of GPT-J (with all optimizations enabled) matches the
  75. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  76. forward pass in fp16, when compared to the HF forward pass in fp32.
  77. """
  78. dtype = torch.float16
  79. device = "cuda"
  80. config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
  81. config.use_flash_attn = True # FlashAttention-2 supports headdim 256
  82. config.fused_bias_fc = True
  83. config.fused_mlp = True
  84. config.fused_dropout_add_ln = True
  85. # Only prenorm supports residual_in_fp32
  86. config.residual_in_fp32 = True
  87. tokenizer = AutoTokenizer.from_pretrained(model_name)
  88. eos_token_id = tokenizer.eos_token_id
  89. torch.manual_seed(0)
  90. batch_size = 1
  91. seqlen = 100
  92. max_length = 150
  93. input_ids = torch.randint(
  94. 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
  95. )
  96. model_hf = GPTJForCausalLM.from_pretrained(
  97. model_name, torch_dtype=dtype, device_map={"": device}
  98. )
  99. model_hf.eval()
  100. print("HF fp16")
  101. torch.cuda.synchronize()
  102. start = time.time()
  103. out_hf = model_hf.generate(
  104. input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
  105. )
  106. torch.cuda.synchronize()
  107. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  108. del model_hf
  109. model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
  110. model_ref.eval()
  111. with torch.no_grad():
  112. logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
  113. del model_ref
  114. model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
  115. model.eval()
  116. print("Without CUDA graph")
  117. torch.cuda.synchronize()
  118. start = time.time()
  119. out = model.generate(
  120. input_ids=input_ids,
  121. max_length=max_length,
  122. eos_token_id=eos_token_id,
  123. return_dict_in_generate=True,
  124. output_scores=True,
  125. enable_timing=True,
  126. teacher_outputs=out_hf.sequences,
  127. )
  128. torch.cuda.synchronize()
  129. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  130. # Capture graph outside the timing loop
  131. batch_size, seqlen_og = input_ids.shape
  132. model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
  133. print("With CUDA graph")
  134. torch.cuda.synchronize()
  135. start = time.time()
  136. out_cg = model.generate(
  137. input_ids=input_ids,
  138. max_length=max_length,
  139. cg=True,
  140. return_dict_in_generate=True,
  141. output_scores=True,
  142. enable_timing=True,
  143. teacher_outputs=out_hf.sequences,
  144. )
  145. torch.cuda.synchronize()
  146. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  147. with torch.no_grad():
  148. logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
  149. logits_hf = torch.stack(out_hf.scores, dim=1)
  150. logits = torch.stack(out.scores, dim=1)
  151. logits_cg = torch.stack(out_cg.scores, dim=1)
  152. del model
  153. hf_error = (logits_hf - logits_ref).abs().max().item()
  154. assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
  155. print(f"HF fp16 logits max diff: {hf_error}")
  156. print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
  157. assert (logits - logits_ref).abs().max().item() < 2 * hf_error
  158. print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
  159. assert torch.equal(logits_cg, logits)