test_fused_dense_parallel.py

# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP
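
# bfloat16 is only natively supported on SM 8.x (Ampere) and newer GPUs, so it is added
# to the dtype sweep only when the device reports compute capability >= 8.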
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias", [True, False])
# @pytest.mark.parametrize('has_bias', [False])
@pytest.mark.parametrize("out_features", [1024])
@pytest.mark.parametrize("in_features", [4096])
def test_fused_linear_bias(
    in_features, out_features, has_bias, sequence_parallel, world_size, dtype
):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
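    # ColumnParallelLinear shards the output dimension across the tensor-parallel group,
    # so each rank holds out_features // world_size rows of the weight (and of the bias).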
    partition_out_features = out_features // world_size
    model = ColumnParallelLinear(
        in_features,
        out_features,
        parallel_state.get_tensor_model_parallel_group(),
        bias=has_bias,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
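    # Copy this rank's shard of the reference weights so the parallel layer should
    # reproduce the corresponding slice of the reference output.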
    with torch.no_grad():
        model.weight.copy_(
            model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        if has_bias:
            model.bias.copy_(
                model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
            )

    out = model(x)
    out_pt = model_pt(x_pt)
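    # Each rank's output should match the corresponding slice of output features
    # in the reference result.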
    assert torch.allclose(
        out,
        out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol,
    )

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out_pt) / 32
    out_pt.backward(g)
    out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features])
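    # Tear down the model-parallel state before the asserts so the next parametrization
    # can initialize it again even if a check below fails.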
    parallel_state.destroy_model_parallel()
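
    # With sequence parallelism the input (and hence x.grad) is sharded along the
    # flattened batch * seqlen dimension, so compare against the matching slice.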
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.weight.grad,
        model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias:
        assert torch.allclose(
            model.bias.grad,
            model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
            rtol=rtol,
            atol=atol * 5,
        )


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize("out_features", [4096])
@pytest.mark.parametrize("in_features", [1024])
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(
        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
    )
    partition_out_features = out_features // world_size
    partition_in_features = in_features // world_size
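    # In ParallelFusedMLP, fc1 is column-parallel (each rank holds a slice of the hidden
    # features) and fc2 is row-parallel, so its bias is applied on rank 0 only to avoid
    # adding it once per rank after the reduction.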
    model = ParallelFusedMLP(
        in_features,
        out_features,
        in_features,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        bias2=has_bias2 and rank == 0,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.fc1.weight.copy_(
            model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        model.fc1.bias.copy_(
            model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
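        # fc2 is row-parallel: each rank holds the columns (input features) of the
        # reference fc2 weight that correspond to its fc1 output shard.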
        model.fc2.weight.copy_(
            model_pt_fc2.weight[
                :, rank * partition_out_features : (rank + 1) * partition_out_features
            ]
        )
        if has_bias2 and rank == 0:
            model.fc2.bias.copy_(model_pt_fc2.bias)

    out = model(x)
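    # The reference path computes Linear -> tanh-approximate GELU -> Linear, matching
    # the activation used by the fused MLP.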
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate="tanh"))
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad,
        model_pt_fc1.weight.grad[
            rank * partition_out_features : (rank + 1) * partition_out_features
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt_fc2.weight.grad[
            :, rank * partition_out_features : (rank + 1) * partition_out_features
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias2 and rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)