# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_block_parallel.py
import math
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange

from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8

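# Compare a tensor-parallel Block (ParallelMHA + ParallelFusedMLP) against the
# single-device reference (MHA + FusedMLP): forward outputs and all gradients
# should match the reference within the tolerances below, with and without
# sequence parallelism.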
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
    head_dim = 64
    assert dim % head_dim == 0
    num_heads = dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient:
    # rank 0 has extra bias parameters, so constructing the models first would
    # put the ranks' RNG states out of sync. Scale g down so the gradients
    # don't get too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
        residual = (
            tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
        residual = residual_pt.detach().clone().requires_grad_()
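    # Reference (non-parallel) block and the tensor-parallel block under test,
    # both using flash attention, rotary embeddings, and fused dropout-add-layernorm.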
    mixer_cls_pt = partial(
        MHA,
        num_heads=num_heads,
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        device=device,
        dtype=dtype,
    )
    mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
    norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
    model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
    with torch.no_grad():
        nn.init.normal_(model_pt.norm1.weight)
        nn.init.normal_(model_pt.norm1.bias)
        nn.init.normal_(model_pt.norm2.weight)
        nn.init.normal_(model_pt.norm2.bias)
    mixer_cls = partial(
        ParallelMHA,
        num_heads=num_heads,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    mlp_cls = partial(
        ParallelFusedMLP,
        hidden_features=4 * dim,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    model = Block(
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls,
        fused_dropout_add_ln=True,
        sequence_parallel=sequence_parallel,
        mark_shared_params=True,
    )
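    # Copy the reference weights into this rank's shard: Wqkv and fc1 are split
    # along the output dimension, out_proj and fc2 along the input dimension,
    # and the out_proj / fc2 biases live only on rank 0.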
    partition_dim = dim // world_size
    partition_hidden_dim = 4 * dim // world_size
    with torch.no_grad():
        model.mixer.Wqkv.weight.copy_(
            rearrange(
                rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o i -> (three o) i",
            )
        )
        model.mixer.Wqkv.bias.copy_(
            rearrange(
                rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o -> (three o)",
            )
        )
        model.mixer.out_proj.weight.copy_(
            model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
        model.mlp.fc1.weight.copy_(
            model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
        )
        model.mlp.fc1.bias.copy_(
            model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
        )
        model.mlp.fc2.weight.copy_(
            model_pt.mlp.fc2.weight[
                :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
            ]
        )
        if rank == 0:
            model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
        model.norm1.weight.copy_(model_pt.norm1.weight)
        model.norm1.bias.copy_(model_pt.norm1.bias)
        model.norm2.weight.copy_(model_pt.norm2.weight)
        model.norm2.bias.copy_(model_pt.norm2.bias)
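    # Forward: the parallel block takes the flattened (and, with sequence parallelism,
    # sharded) inputs; the reference block takes the full (batch, seqlen, dim) tensors.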
    mixer_kwargs = {"seqlen": seqlen}
    out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
    out_pt, out_residual_pt = model_pt(
        rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
        rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
    )
    out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        out_residual,
        out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_residual_pt,
        rtol=rtol,
        atol=atol,
    )
    (out_pt + 2 * out_residual_pt).backward(g)
    (out + 2 * out_residual).backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
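    # With sequence parallelism, gradients of parameters marked as shared
    # (mark_shared_params=True above) need an extra all-reduce before they
    # match the reference.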
    allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol / 10,  # magnitude of x.grad is quite small
    )
    assert torch.allclose(
        residual.grad,
        residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else residual_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.mixer.Wqkv.weight.grad,
        rearrange(
            rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o i -> (three o) i",
        ),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.mixer.Wqkv.bias.grad,
        rearrange(
            rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o -> (three o)",
        ),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.mixer.out_proj.weight.grad,
        model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.mixer.out_proj.bias.grad,
            model_pt.mixer.out_proj.bias.grad,
            rtol=rtol,
            atol=atol * 5,
        )
    assert torch.allclose(
        model.mlp.fc1.weight.grad,
        model_pt.mlp.fc1.weight.grad[
            rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.mlp.fc1.bias.grad,
        model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.mlp.fc2.weight.grad,
        model_pt.mlp.fc2.weight.grad[
            :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
        )
    assert torch.allclose(
        model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(
        model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)