# tests/losses/test_cross_entropy_parallel.py
# Run this test with:
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel.py
  3. import math
  4. import pytest
  5. import torch
  6. from apex.transformer import parallel_state, tensor_parallel
  7. from flash_attn.losses.cross_entropy import CrossEntropyLoss
  8. is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
  9. @pytest.mark.parametrize(
  10. "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
  11. )
  12. # @pytest.mark.parametrize("dtype", [torch.float16])
  13. @pytest.mark.parametrize("precompute_lse", [False, True])
  14. # @pytest.mark.parametrize("precompute_lse", [False])
  15. @pytest.mark.parametrize("inplace_backward", [False, True])
  16. # @pytest.mark.parametrize("inplace_backward", [False])
  17. # @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
  18. @pytest.mark.parametrize("lse_square_scale", [1e-2])
  19. @pytest.mark.parametrize("logit_scale", [1.0, 0.7])
  20. # @pytest.mark.parametrize("logit_scale", [1.0])
  21. @pytest.mark.parametrize("smoothing", [0.0, 0.9])
  22. # @pytest.mark.parametrize("smoothing", [0.0])
  23. @pytest.mark.parametrize("vocab_size", [50264, 256 * 1024]) # test vocab larger than 64k for split
  24. # @pytest.mark.parametrize("vocab_size", [50264]) # test vocab larger than 64k for split
  25. # @pytest.mark.parametrize("world_size", [1, 2])
  26. @pytest.mark.parametrize("world_size", [2])
  27. def test_cross_entropy_loss_parallel(
  28. vocab_size,
  29. world_size,
  30. smoothing,
  31. logit_scale,
  32. lse_square_scale,
  33. inplace_backward,
  34. precompute_lse,
  35. dtype,
  36. ):
  37. if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
  38. pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
  39. assert vocab_size % world_size == 0
  40. rtol, atol = (
  41. (1e-5, 2e-5)
  42. if dtype == torch.float32
  43. else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
  44. )
  45. if not torch.distributed.is_initialized():
  46. torch.distributed.init_process_group(backend="nccl", init_method="env://")
  47. partition_vocab_size = vocab_size // world_size
  48. device = f"cuda:{torch.distributed.get_rank()}"
  49. assert world_size <= torch.distributed.get_world_size()
  50. parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
  51. rank = parallel_state.get_tensor_model_parallel_rank()
  52. # set seed
  53. torch.random.manual_seed(0)
  54. batch_size = 8
  55. seqlen = 128
  56. x_pt = (
  57. torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
  58. ).requires_grad_()
  59. x = (
  60. tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
  61. .detach()
  62. .clone()
  63. .requires_grad_()
  64. )
  65. y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
  66. y[torch.randperm(batch_size * seqlen)[:10]] = -100
  67. model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
  68. model = CrossEntropyLoss(
  69. label_smoothing=smoothing,
  70. logit_scale=logit_scale,
  71. reduction="none",
  72. lse_square_scale=lse_square_scale,
  73. inplace_backward=inplace_backward,
  74. process_group=parallel_state.get_tensor_model_parallel_group(),
  75. )
  76. if precompute_lse:
  77. with torch.no_grad():
  78. lse = torch.logsumexp(x.float(), dim=-1)
  79. else:
  80. lse = None
  81. out = model(x, y, precomputed_lse=lse)
  82. out_pt = model_pt(x_pt.float() * logit_scale, y)
  83. if lse_square_scale > 0.0:
  84. lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
  85. out_pt += lse_square_scale * lse_pt.square()
  86. out_pt.masked_fill_(y == -100, 0.0)
  87. assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
  88. g = torch.randn_like(out)
  89. out_pt.backward(g)
  90. out.backward(g)
  91. assert torch.allclose(
  92. x.grad,
  93. x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
  94. rtol=rtol,
  95. atol=atol,
  96. )
  97. parallel_state.destroy_model_parallel()