# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange

from flash_attn.modules.mha import MHA, ParallelMHA
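
# bfloat16 is only supported on SM 8.x (Ampere) and newer GPUs, so gate it on the device capability.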
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("head_dim", [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
    assert embed_dim % head_dim == 0
    num_heads = embed_dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
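    # Keep the input flattened as (batch * seqlen, embed_dim) so the sequence-parallel
    # scatter below can split it along the first dimension.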
    x_pt = torch.randn(
        batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
    )
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
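    # Reference: a single non-parallel MHA holding all heads, with rotary embeddings
    # applied to half of each head dimension.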
    model_pt = MHA(
        embed_dim,
        num_heads,
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        device=device,
        dtype=dtype,
    )
    partition_dim = embed_dim // world_size
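    # Tensor-parallel MHA: each rank in the TP group owns a partition_dim-wide shard
    # of the QKV and output projections.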
    model = ParallelMHA(
        embed_dim,
        num_heads,
        parallel_state.get_tensor_model_parallel_group(),
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
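    # Copy this rank's shard of the reference weights into the parallel model:
    # Wqkv is sliced per q/k/v along the output dim, out_proj along the input dim,
    # and only rank 0 keeps the out_proj bias.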
    with torch.no_grad():
        model.Wqkv.weight.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o i -> (three o) i",
            )
        )
        model.Wqkv.bias.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o -> (three o)",
            )
        )
        model.out_proj.weight.copy_(
            model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.out_proj.bias.copy_(model_pt.out_proj.bias)
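    # Forward: with sequence parallelism each rank holds its slice of the flattened
    # (batch * seqlen) dimension; otherwise the outputs should match in full.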
    out = model(x, seqlen=seqlen)
    out_pt = rearrange(model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()
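    # Check this rank's input-gradient slice and weight/bias gradient shards
    # against the reference model.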
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol / 100,  # magnitude of x.grad is quite small
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.Wqkv.weight.grad,
        rearrange(
            rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o i -> (three o) i",
        ),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.Wqkv.bias.grad,
        rearrange(
            rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o -> (three o)",
        ),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.out_proj.weight.grad,
        model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
        )