# test_gpt_neox.py

import time

import pytest
import torch
from transformers import AutoTokenizer, GPTNeoXConfig
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
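

# The remapped HF checkpoint must expose exactly the same parameter names and
# shapes as flash-attn's GPTLMHeadModel, so that a strict load_state_dict would succeed.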
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name),
                                                         config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape
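

# Once keys and shapes match, loading the remapped weights for actual inference
# is a one-liner; a rough sketch (left commented out here, since it would
# materialize all 20B parameters):
#
#     model = GPTLMHeadModel(config, device='cuda', dtype=torch.float16)
#     model.load_state_dict(pretrained_state_dict)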


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name):
    """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
    HF implementation: our fp16 forward pass should be about as close to the HF fp32 forward
    pass as the HF fp16 forward pass is.
    """
    dtype = torch.float16
    device = 'cuda'
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
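    # The flags above swap in fused CUDA kernels: FlashAttention for attention,
    # fused bias + GEMM for the linear layers, a fused MLP (which requires an
    # approximate GeLU such as "gelu_fast"), and fused dropout + residual-add +
    # LayerNorm. residual_in_fp32 keeps the residual stream in fp32 for
    # numerical stability while the rest of the model runs in fp16.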
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)  # final hidden states from the backbone
        logits = model(input_ids).logits    # after the LM head
    del model
    # The fp32 reference weighs ~80 GB (20B params x 4 bytes), so we need at
    # least 2 GPUs, otherwise we'll OOM
    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref
    model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
                                                  device_map={"": device})
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf
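
    # Hidden-state and logit errors vs. the fp32 reference: the optimized model
    # must stay within 2x of the HF fp16 baseline's error.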
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
    assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
    assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item()
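

# The tolerance pattern above appears twice; a minimal sketch of a hypothetical
# helper that expresses it once (not used by the test above): the optimized fp16
# output must be at most `factor` times as far from the fp32 reference as the
# baseline HF fp16 output is, in both max and mean absolute error.
def _assert_within_factor_of_ref(out, out_baseline, out_ref, factor=2.0):
    assert (out - out_ref).abs().max().item() < factor * (out_baseline - out_ref).abs().max().item()
    assert (out - out_ref).abs().mean().item() < factor * (out_baseline - out_ref).abs().mean().item()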