test_opt.py

import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM


@pytest.mark.parametrize(
    "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape
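

# Usage sketch (comments only, nothing executed here): once the keys and shapes above match, the
# remapped HF weights can be loaded with the standard PyTorch API, e.g.
#     model.load_state_dict(pretrained_state_dict)
# The tests below rely on GPTLMHeadModel.from_pretrained to perform this remap + load internally.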


@pytest.mark.parametrize(
    "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
    """Check that our implementation of OPT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8
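    # Three models are compared: ours with the fused kernels above (fp16), an fp32 HF reference,
    # and an fp16 HF baseline that sets the error scale we allow ourselves (see asserts below).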
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    if model_name != "facebook/opt-350m":  # The OPT-350m projects the embeddings to dimension 512
        out = model.transformer(input_ids)
        out_hf = model_hf.model(input_ids).last_hidden_state
        out_ref = model_ref.model(input_ids).last_hidden_state
        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize(
    "model_name",
    [
        "facebook/opt-125m",
        "facebook/opt-350m",
        "facebook/opt-1.3b",
        "facebook/opt-2.7b",
        "facebook/opt-6.7b",
    ],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_opt_generation(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
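    # rtol/atol are used only to compare our generate() scores against the step-by-step greedy
    # reference computed below; the cross-implementation check uses the 3x-of-HF-fp16 rule.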
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)
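    # Reference result: `sequences` has shape (batch, prompt_len + new_tokens); `scores` is a
    # tuple of per-step (batch, vocab_size) logits, mirroring what generate() returns below.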

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    if verbose:
        print(out.sequences)
        print(tokenizer.batch_decode(out.sequences.tolist()))

    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
            print(tokenizer.batch_decode(out_cg.sequences.tolist()))
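    # Not asserted here, but a natural extra check (sketch, assuming the CUDA-graph branch ran)
    # would be that graph-captured decoding matches eager decoding:
    #     assert torch.equal(out_cg.sequences, out.sequences)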

    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.time()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_ref

    print(tokenizer.batch_decode(out_ref.sequences.tolist()))
    if verbose:
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
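
    # Final checks: our generate() must reproduce the step-by-step greedy reference exactly on
    # sequences and closely on scores, agree with both HF implementations on the sequences, and
    # keep its fp16 score error within 3x of HF fp16's error vs. the fp32 reference.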
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()
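

# These tests require a CUDA device and the fused kernels enabled via the config flags above
# (use_flash_attn, fused_bias_fc, fused_mlp, fused_dropout_add_ln). An illustrative invocation
# (flags and -k filter are examples, not prescribed by this file):
#     pytest -q -s test_opt.py -k "opt_generation and 125m"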