# Copyright (c) 2023, Tri Dao.
import pytest
import torch

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPTNeoXConfig
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
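
# These tests are usually run with pytest, e.g. (the path is assumed from the
# flash-attn repo layout):
#   pytest tests/models/test_gpt_neox.py -k "optimized"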

@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
    """Check that the remapped HF GPT-NeoX state dict matches our GPTLMHeadModel
    state dict in both keys and parameter shapes."""
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gpt_neox(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device="meta" init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape
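
# A minimal sketch (not part of the test) of why device="meta" above makes
# constructing the 20B-parameter model cheap: meta tensors carry shape/dtype
# metadata but allocate no storage. The layer size here is illustrative only,
# not taken from the GPT-NeoX config.
def _meta_device_sketch():
    layer = torch.nn.Linear(4096, 4096, device="meta")
    assert layer.weight.is_meta  # no memory was allocated
    assert layer.weight.shape == (4096, 4096)  # shapes remain checkable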

@pytest.mark.parametrize(
    "model_name",
    [
        "EleutherAI/pythia-1b",
        "EleutherAI/pythia-2.8b",
        "EleutherAI/gpt-neox-20b",
        "togethercomputer/RedPajama-INCITE-7B-Base",
    ],
)
def test_gpt_neox_optimized(model_name):
    """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
    HF implementation: the error of our fp16 forward pass relative to the HF fp32 forward pass
    should be about the same as (in the asserts below, within 2x of) the error of the HF fp16
    forward pass relative to the same fp32 reference.
    """
    dtype = torch.float16
    device = "cuda"
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    # The fused MLP path is only enabled for approximate-GeLU activations
    config.fused_mlp = config.activation_function in [
        "gelu_fast",
        "gelu_new",
        "gelu_approx",
        "gelu_pytorch_tanh",
    ]
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
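    # Aside: the four activation names above are all spellings of the same tanh
    # approximation of GeLU, roughly
    #   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))),
    # which is why fused_mlp is gated on exactly these activations.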

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model  # free GPU memory before loading the reference models

    # Need at least 2 GPUs, otherwise we'll OOM for the 20B model.
    # Without device_map, the model is loaded on the CPU, which is very slow.
    model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = GPTNeoXForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf
  77. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  78. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  79. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  80. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  81. assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
  82. assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
  83. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  84. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  85. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  86. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  87. assert (logits - logits_ref).abs().max().item() < 2 * (
  88. logits_hf - logits_ref
  89. ).abs().max().item()
  90. assert (logits - logits_ref).abs().mean().item() < 2 * (
  91. logits_hf - logits_ref
  92. ).abs().mean().item()
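
# A minimal sketch (hypothetical helper, not used above) restating the acceptance
# criterion of test_gpt_neox_optimized: the fp16 error of our implementation
# against the fp32 reference must stay within 2x the HF fp16 error against the
# same reference, for both the max and the mean absolute error.
def _within_2x_hf_error(ours: torch.Tensor, hf: torch.Tensor, ref: torch.Tensor) -> bool:
    our_err = (ours - ref).abs()
    hf_err = (hf - ref).abs()
    return (
        our_err.max().item() < 2 * hf_err.max().item()
        and our_err.mean().item() < 2 * hf_err.mean().item()
    )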