# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    vocab_size = 50264
    seqlen = 2048
    assert vocab_size % world_size == 0
    assert dim % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
    input_ids = input_ids_pt.detach().clone()
    model_pt = GPT2Embeddings(
        dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype
    )
    model = ParallelGPT2Embeddings(
        dim,
        vocab_size,
        seqlen if has_pos_emb else 0,
        parallel_state.get_tensor_model_parallel_group(),
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
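    # Copy the reference weights into the parallel model's shards: the word embedding
    # is partitioned along the vocab dimension (rows), the position embedding along
    # the embedding dimension (columns).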
    partition_vocab_size = vocab_size // world_size
    partition_dim = dim // world_size
    with torch.no_grad():
        model.word_embeddings.weight.copy_(
            model_pt.word_embeddings.weight[
                rank * partition_vocab_size : (rank + 1) * partition_vocab_size
            ]
        )
        if has_pos_emb:
            model.position_embeddings.weight.copy_(
                model_pt.position_embeddings.weight[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ]
            )
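    # Forward: with sequence_parallel=True the output is scattered along the flattened
    # (batch * seqlen) dimension, so each rank is compared against its slice of the
    # reference output.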
    out = model(input_ids, combine_batch_seqlen_dim=True)
    out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
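    # Backward: feed each rank the slice of the upstream gradient that matches its output shard.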
    g = torch.randn_like(out_pt)
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

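    # Each rank should recover the gradient of its own weight partition.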
    assert torch.allclose(
        model.word_embeddings.weight.grad,
        model_pt.word_embeddings.weight.grad[
            rank * partition_vocab_size : (rank + 1) * partition_vocab_size
        ],
        rtol=rtol,
        atol=atol,
    )
    if has_pos_emb:
        assert torch.allclose(
            model.position_embeddings.weight.grad,
            model_pt.position_embeddings.weight.grad[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            rtol=rtol,
            atol=atol,
        )