# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange

from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
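
# Only include bfloat16 in the dtype sweep on SM 8.x (Ampere) or newer devices.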
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
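    # Compare ParallelGatedMlp, sharded over the tensor-parallel group, against a
    # single-process GatedMlp reference built from the same weights and input.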
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # Scale g down; otherwise the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
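    # In sequence-parallel mode each rank processes only its own slice of the
    # flattened (batch * seqlen) tokens, so scatter the reference input first.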
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
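
    # fc1 of GatedMlp produces both halves of the gated projection stacked along
    # the output dimension; partition_dim is this rank's slice of each half.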
    model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
    partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
    model = ParallelGatedMlp(
        dim,
        parallel_state.get_tensor_model_parallel_group(),
        activation=activation,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
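
    # Copy this rank's shard of the reference weights into the parallel model:
    # fc1 is split along the output dimension (separately for the two halves of the
    # gated projection), fc2 along the input dimension, and only rank 0 gets the
    # fc2 bias.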
    with torch.no_grad():
        model.fc1.weight.copy_(
            rearrange(
                rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o i -> (two o) i",
            )
        )
        model.fc1.bias.copy_(
            rearrange(
                rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o -> (two o)",
            )
        )
        model.fc2.weight.copy_(
            model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.fc2.bias.copy_(model_pt.fc2.bias)
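
    # Forward pass: with sequence parallelism the parallel output is only this
    # rank's token slice, so compare against the matching slice of the reference.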
    out = model(x)
    out_pt = model_pt(x_pt)
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
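
    # Backward pass with the shared upstream gradient g (sliced per rank when
    # sequence parallelism is on).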
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()
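
    # Check the input and parameter gradients against the corresponding shards of
    # the reference gradients.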
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc1.weight.grad,
        rearrange(
            rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o i -> (two o) i",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        rearrange(
            rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o -> (two o)",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol,
    )
    if rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)