test_bigcode.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import time
  2. import pytest
  3. import torch
  4. from transformers import AutoTokenizer, GPTBigCodeConfig
  5. from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
  6. from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode
  7. from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode
  8. from flash_attn.utils.generation import update_graph_cache
  9. from flash_attn.utils.pretrained import state_dict_from_pretrained
  10. @pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
  11. def test_bigcode_state_dict(model_name):
  12. config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
  13. pretrained_state_dict = remap_state_dict_hf_bigcode(
  14. state_dict_from_pretrained(model_name), config
  15. )
  16. model = GPTLMHeadModel(config, device="meta")
  17. state_dict = model.state_dict()
  18. assert state_dict.keys() == pretrained_state_dict.keys()
  19. for k in state_dict.keys():
  20. assert state_dict[k].shape == pretrained_state_dict[k].shape
  21. @pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
  22. def test_bigcode_optimized(model_name):
  23. """Check that our implementation of BigCode (with all optimizations enabled) matches the
  24. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  25. forward pass in fp16, when compared to the HF forward pass in fp32.
  26. """
  27. dtype = torch.float16
  28. device = "cuda"
  29. config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
  30. config.use_flash_attn = True # FlashAttention-2 supports headdim 256
  31. config.fused_bias_fc = True
  32. config.fused_mlp = True
  33. config.fused_dropout_add_ln = True
  34. config.residual_in_fp32 = True
  35. model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
  36. model.eval()
  37. torch.manual_seed(0)
  38. batch_size = 2
  39. max_seqlen = 256
  40. input_ids = torch.randint(
  41. 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
  42. )
  43. with torch.no_grad():
  44. out = model.transformer(input_ids)
  45. logits = model(input_ids).logits
  46. del model
  47. # Without device_map, the model is loaded on the CPU, which is very slow
  48. model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
  49. model_ref.eval()
  50. with torch.no_grad():
  51. out_ref = model_ref.transformer(input_ids).last_hidden_state
  52. logits_ref = model_ref(input_ids).logits
  53. del model_ref
  54. model_hf = GPTBigCodeForCausalLM.from_pretrained(
  55. model_name, torch_dtype=dtype, device_map={"": device}
  56. )
  57. model_hf.eval()
  58. out_hf = model_hf.transformer(input_ids).last_hidden_state
  59. logits_hf = model_hf(input_ids).logits
  60. del model_hf
  61. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  62. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  63. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  64. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  65. assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
  66. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  67. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  68. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  69. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  70. assert (logits - logits_ref).abs().max().item() < 3 * (
  71. logits_hf - logits_ref
  72. ).abs().max().item()
  73. @pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
  74. def test_bigcode_generation(model_name):
  75. """Check that our implementation of BigCode (with all optimizations enabled) matches the
  76. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  77. forward pass in fp16, when compared to the HF forward pass in fp32.
  78. """
  79. dtype = torch.float16
  80. device = "cuda"
  81. config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
  82. config.use_flash_attn = True # FlashAttention-2 supports headdim 256
  83. config.fused_bias_fc = True
  84. config.fused_mlp = True
  85. config.fused_dropout_add_ln = True
  86. # Only prenorm supports residual_in_fp32
  87. config.residual_in_fp32 = True
  88. tokenizer = AutoTokenizer.from_pretrained(model_name)
  89. eos_token_id = tokenizer.eos_token_id
  90. torch.manual_seed(0)
  91. batch_size = 1
  92. seqlen = 100
  93. max_length = 150
  94. input_ids = torch.randint(
  95. 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
  96. )
  97. model_hf = GPTBigCodeForCausalLM.from_pretrained(
  98. model_name, torch_dtype=dtype, device_map={"": device}
  99. )
  100. model_hf.eval()
  101. print("HF fp16")
  102. torch.cuda.synchronize()
  103. start = time.time()
  104. out_hf = model_hf.generate(
  105. input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
  106. )
  107. torch.cuda.synchronize()
  108. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  109. del model_hf
  110. model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
  111. model_ref.eval()
  112. with torch.no_grad():
  113. logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
  114. del model_ref
  115. model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
  116. model.eval()
  117. print("Without CUDA graph")
  118. torch.cuda.synchronize()
  119. start = time.time()
  120. out = model.generate(
  121. input_ids=input_ids,
  122. max_length=max_length,
  123. eos_token_id=eos_token_id,
  124. return_dict_in_generate=True,
  125. output_scores=True,
  126. enable_timing=True,
  127. teacher_outputs=out_hf.sequences,
  128. )
  129. torch.cuda.synchronize()
  130. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  131. # Capture graph outside the timing loop
  132. batch_size, seqlen_og = input_ids.shape
  133. model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
  134. print("With CUDA graph")
  135. torch.cuda.synchronize()
  136. start = time.time()
  137. out_cg = model.generate(
  138. input_ids=input_ids,
  139. max_length=max_length,
  140. cg=True,
  141. return_dict_in_generate=True,
  142. output_scores=True,
  143. enable_timing=True,
  144. teacher_outputs=out_hf.sequences,
  145. )
  146. torch.cuda.synchronize()
  147. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  148. with torch.no_grad():
  149. logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
  150. logits_hf = torch.stack(out_hf.scores, dim=1)
  151. logits = torch.stack(out.scores, dim=1)
  152. logits_cg = torch.stack(out_cg.scores, dim=1)
  153. del model
  154. hf_error = (logits_hf - logits_ref).abs().max().item()
  155. assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
  156. print(f"HF fp16 logits max diff: {hf_error}")
  157. print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
  158. assert (logits - logits_ref).abs().max().item() < 2 * hf_error
  159. print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
  160. assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error
  161. @pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
  162. def test_inv_remap_state_dict(model_name: str):
  163. """
  164. Verify that we can convert a HF BigCode model to flash_attn and back.
  165. """
  166. state_dict = state_dict_from_pretrained(model_name)
  167. config = GPTBigCodeConfig.from_pretrained(model_name)
  168. flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config)
  169. recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config)
  170. assert set(state_dict.keys()) == set(recovered_state_dict.keys())
  171. for k in state_dict.keys():
  172. assert state_dict[k].shape == recovered_state_dict[k].shape
  173. torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)