# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel"
import os
import re

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF


# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_tensor_parallel(model_name, rotary, world_size):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if rotary:
        config.n_positions = 0
        config.rotary_emb_dim = 64
    config.residual_in_fp32 = True
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
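    # Pad the vocab so the embedding / LM head can be split evenly across the tensor-parallel ranks.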
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)
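    # Use apex's parallel_state to set up the tensor-model-parallel process group.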
    from apex.transformer import parallel_state

    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()
    # If rotary, we load the weights from HF but ignore the position embeddings.
    # The model will be nonsense, but that doesn't matter for the test.
    model = GPTLMHeadModel.from_pretrained(
        model_name,
        config,
        strict=not rotary,
        device=device,
        dtype=dtype,
        process_group=process_group,
        world_size=world_size,
        rank=rank,
    )
    model.eval()

    if not rotary:
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 30
    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
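    # Greedily decode one token at a time, re-running the full prefix at every step; this serves
    # as the ground truth for the fast generate() calls below.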
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
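        # The LM head is sharded over the vocab dimension, so each rank only returns logits for
        # its shard. all_gather_raw concatenates the shards along dim 0; rearrange moves them back
        # onto the vocab dimension, and the slice drops the padding columns.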
        logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
        logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
            ..., : config.vocab_size
        ]
        scores.append(logits)
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
            logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
                ..., : config.vocab_size
            ]
            scores.append(logits)
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)
    print(sequences)
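    # Fast generation through the model's generate() method (decoding incrementally instead of
    # re-running the full prefix each step).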
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    print(out.sequences)
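    # Run generation again with CUDA graph capture (cg=True); the scores should match the
    # non-graph run exactly.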
    if getattr(config, "use_flash_attn", False):
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            tensor_parallel=world_size,
            vocab_size=config.vocab_size,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        print(out_cg.sequences)

    parallel_state.destroy_model_parallel()
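    # Compare against HF: model_ref (fp32) is the reference, model_hf (fp16) sets the scale of
    # error we are willing to tolerate.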
    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ref = model_ref.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )

        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
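    # The parallel implementation must reproduce the slow reference exactly (same greedy tokens),
    # with scores close to it, and the CUDA-graph run must match the non-graph run bit for bit.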
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)
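        # Our fp16 scores should be no more than 3x further from the fp32 HF scores than HF's own
        # fp16 scores are.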
        assert (
            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item() < 3 * (
            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item()