# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
import math

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8

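# The test body runs once per combination of the parametrized values below. Each
# parametrization re-initializes Apex tensor-model-parallel state with the requested
# world_size, so launch with enough ranks (nproc_per_node=8 above) to cover the largest
# value; smaller sizes partition the launched ranks into several tensor-parallel groups.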
  15. @pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
  16. # @pytest.mark.parametrize('dtype', [torch.bfloat16])
  17. @pytest.mark.parametrize("world_size", [1, 2, 4, 8])
  18. # @pytest.mark.parametrize('world_size', [2])
  19. @pytest.mark.parametrize("sequence_parallel", [True, False])
  20. # @pytest.mark.parametrize('sequence_parallel', [False])
  21. @pytest.mark.parametrize("has_pos_emb", [True, False])
  22. # @pytest.mark.parametrize('has_pos_emb', [True])
  23. @pytest.mark.parametrize("dim", [1024])
  24. def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
  25. head_dim = 64
  26. assert dim % head_dim == 0
  27. num_heads = dim // head_dim
  28. assert num_heads % world_size == 0
  29. vocab_size = 50264
  30. assert vocab_size % world_size == 0
  31. num_layers = 2
  32. rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
  33. if not torch.distributed.is_initialized():
  34. torch.distributed.init_process_group(backend="nccl", init_method="env://")
  35. device = f"cuda:{torch.distributed.get_rank()}"
  36. assert world_size <= torch.distributed.get_world_size()
  37. parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
  38. rank = parallel_state.get_tensor_model_parallel_rank()
  39. process_group = parallel_state.get_tensor_model_parallel_group()
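    # All tensor-parallel collectives below go through this process group. With
    # sequence_parallel=True the model additionally splits activations along the
    # sequence dimension across the same group (Megatron-style sequence parallelism).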
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    g = torch.randn(batch_size * seqlen, device=device)
    config = GPT2Config(
        n_embd=dim,
        n_head=num_heads,
        n_layer=num_layers,
        n_positions=seqlen if has_pos_emb else 0,
        vocab_size=50257,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        scale_attn_by_inverse_layer_idx=True,
        use_flash_attn=True,
        fused_mlp=True,
        fused_bias_fc=True,
        fused_dropout_add_ln=True,
        residual_in_fp32=True,
        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
        pad_vocab_size_multiple=8 * world_size,
        sequence_parallel=sequence_parallel,
    )
    config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
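    # Round the vocab size up to a multiple of 8 * world_size so the padded vocabulary
    # divides evenly across the tensor-parallel ranks; this mirrors the padding the
    # model applies via pad_vocab_size_multiple and keeps partition_vocab_size below
    # consistent with the actual embedding shards.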
    model_pt = GPTLMHeadModel(config, device=device)

    def init_layer_norm(module):
        if isinstance(module, nn.LayerNorm):
            nn.init.normal_(module.weight)
            nn.init.normal_(module.bias)

    model_pt.apply(init_layer_norm)
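    # Randomize the LayerNorm affine parameters of the reference model so the gradient
    # comparisons below are not made against the trivial default initialization
    # (weight = 1, bias = 0).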
    model = GPTLMHeadModel(config, process_group=process_group, device=device)
    total_nparams = sum(p.numel() for p in model_pt.parameters())
    sharded_nparams = sum(p.numel() for p in model.parameters())
    sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
    torch.distributed.all_gather_into_tensor(
        sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
    )
    shared_nparams = sum(
        p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
    )
    shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
    torch.distributed.all_gather_into_tensor(
        shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
    )
    assert torch.all(shared_nparams_all == shared_nparams)
    assert total_nparams == (
        (sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
    )
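    # Sanity check on the sharding: every rank must report the same number of shared
    # (replicated) parameters, and the reference parameter count must equal the
    # per-rank shard sizes summed over ranks, with the shared parameters counted once.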
    # vocab_size has been rounded up here
    partition_vocab_size = config.vocab_size // world_size
    partition_dim = dim // world_size
    partition_hidden_dim = 4 * dim // world_size
    with torch.no_grad():
        model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
        model.tie_weights()
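    # Load this rank's slice of the reference weights (shard_state_dict_tp splits the
    # full state dict the same way the tensor-parallel model shards its parameters),
    # then re-tie the word embedding and LM head after the load.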
    with torch.autocast(device_type="cuda", dtype=dtype):
        out = model(input_ids[:, :-1]).logits
        if not sequence_parallel:
            out = rearrange(out, "b s d -> (b s) d")
        out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
        partition_batch_dim = batch_size * seqlen // world_size
        assert torch.allclose(
            out,
            out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
            rtol=rtol,
            atol=atol,
        )
        loss_fn = CrossEntropyLoss(
            inplace_backward=True, reduction="none", process_group=process_group
        )
        loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
        loss = loss_fn(out, input_ids[:, 1:].flatten())
        loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
        assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
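    # loss_fn is the vocab-parallel cross entropy from flash_attn: each rank holds only
    # its slice of the logits and the loss is reduced across process_group, so it
    # should match the reference loss computed from the full logits.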
    loss_pt.backward(g)
    loss.backward(g)
    allreduce_sequence_parallel_grad(model, process_group)
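    # With sequence parallelism, parameters that are replicated across the
    # tensor-parallel group (e.g. the LayerNorms) only see gradients from the local
    # chunk of the sequence, so their gradients must be all-reduced over the group
    # before they can be compared to the reference.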
    parallel_state.destroy_model_parallel()

    grad_dict = shard_state_dict_tp(
        {k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
    )
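    # Shard the reference model's gradients exactly like the weights so each rank can
    # compare its local gradient shards slice-for-slice.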
    assert torch.allclose(
        model.transformer.embeddings.word_embeddings.weight.grad,
        grad_dict["transformer.embeddings.word_embeddings.weight"],
        rtol=rtol,
        atol=atol * 5,
    )
    if has_pos_emb:
        assert torch.allclose(
            model.transformer.embeddings.position_embeddings.weight.grad,
            grad_dict["transformer.embeddings.position_embeddings.weight"],
            rtol=rtol,
            atol=atol,
        )
    assert torch.allclose(
        model.transformer.ln_f.weight.grad,
        grad_dict["transformer.ln_f.weight"],
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
    )
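    # Per-layer gradient checks. The matmul-heavy weights (Wqkv, out_proj, fc1, fc2)
    # get a looser atol than the LayerNorm parameters, and the biases of the
    # row-parallel projections (out_proj, fc2) are only checked on rank 0, the rank
    # that holds them (see the comment about g above).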
    for i in range(num_layers):
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.weight.grad,
            grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.Wqkv.bias.grad,
            grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mixer.out_proj.weight.grad,
            grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.transformer.layers[i].mixer.out_proj.bias.grad,
                grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
                rtol=rtol,
                atol=atol * 5,
            )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.weight.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc1.bias.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
            rtol=rtol,
            atol=atol * 10,
        )
        assert torch.allclose(
            model.transformer.layers[i].mlp.fc2.weight.grad,
            grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
            rtol=rtol,
            atol=atol * 10,
        )
        if rank == 0:
            assert torch.allclose(
                model.transformer.layers[i].mlp.fc2.bias.grad,
                grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
                rtol=rtol,
                atol=atol * 5,
            )
        assert torch.allclose(
            model.transformer.layers[i].norm1.weight.grad,
            grad_dict[f"transformer.layers.{i}.norm1.weight"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm1.bias.grad,
            grad_dict[f"transformer.layers.{i}.norm1.bias"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm2.weight.grad,
            grad_dict[f"transformer.layers.{i}.norm2.weight"],
            rtol=rtol,
            atol=atol,
        )
        assert torch.allclose(
            model.transformer.layers[i].norm2.bias.grad,
            grad_dict[f"transformer.layers.{i}.norm2.bias"],
            rtol=rtol,
            atol=atol,
        )