tests/losses/test_cross_entropy_parallel.py
# Run test with:
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math

import pytest
import torch
from apex.transformer import parallel_state, tensor_parallel

from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [0.0])
@pytest.mark.parametrize("logit_scale", [0.7])
# @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [50264])
# @pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
    vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
):
    assert vocab_size % world_size == 0
    rtol, atol = (
        (1e-5, 2e-5)
        if dtype == torch.float32
        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
    )
    # Set up the distributed backend and tensor-model-parallel state
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    partition_vocab_size = vocab_size // world_size
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    # x_pt holds the full-vocab reference logits; x is this rank's vocab shard
    x_pt = (
        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
    ).requires_grad_()
    x = (
        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
        .detach()
        .clone()
        .requires_grad_()
    )
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    # Mark a few positions with the ignore_index (-100)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        logit_scale=logit_scale,
        reduction="none",
        lse_square_scale=lse_square_scale,
        inplace_backward=inplace_backward,
        process_group=parallel_state.get_tensor_model_parallel_group(),
    )
    out = model(x, y)
    out_pt = model_pt(x_pt.float() * logit_scale, y)
    if lse_square_scale > 0.0:
        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
        out_pt += lse_square_scale * lse_pt.square()
        out_pt.masked_fill_(y == -100, 0.0)
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    # Each rank's gradient should match the corresponding vocab slice of the reference gradient
    assert torch.allclose(
        x.grad,
        x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
        rtol=rtol,
        atol=atol,
    )
    parallel_state.destroy_model_parallel()