cross_entropy_parallel.py

# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
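#
# The identity this relies on: if the vocabulary is sharded across ranks and lse_r is the
# logsumexp of rank r's logit shard, then the global logsumexp is
#     lse = logsumexp_r(lse_r) = log(sum_r exp(lse_r)),
# and each per-sample loss is simply lse - logit_at_label. So every rank runs the usual
# kernel on its own shard to get a local loss and a local LSE, and one all-gather of the
# LSEs plus one all-reduce of the losses is enough to reconstruct the exact global loss
# (see the single-process sanity-check sketch at the bottom of this file).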

import torch
import torch.nn as nn

import xentropy_cuda_lib

from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent version of
# PyTorch. The following 4 lines are for backward compatibility with older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
                inplace_backward=False):
        """
        logits_parallel: (batch, vocab_size / world_size)
        labels: (batch,)
        """
        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
        batch, partition_vocab_size = logits_parallel.shape
        assert labels.shape == (batch,)
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
                partition_vocab_size, get_tensor_model_parallel_rank(),
                get_tensor_model_parallel_world_size()
            )

            # Mask of labels that live on another rank's vocab shard (True means masked).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
            # Labels owned by other ranks are treated as ignored here so the local kernel
            # contributes zero loss for them; the owning rank supplies the real local loss.
            masked_labels = labels_local.clone()
            masked_labels[labels_mask] = ignored_index

            losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(masked_labels == ignored_index, 0)

            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            # Overlap the two collectives: gather the per-rank LSEs while summing the losses.
            handle_lse = torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(),
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM,
                group=get_tensor_model_parallel_group(), async_op=True
            )
            handle_lse.wait()
            lse = torch.logsumexp(lse_allgather, dim=0)
            # At this point each loss is lse_local - predicted_logit, where lse_local belongs
            # to the rank that owns the label; we just subtract that lse_local and add the
            # global lse.
            rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            losses += lse - lse_local
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits_parallel, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits_parallel, lse, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None, None


class CrossEntropyLossParallel(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossParallelFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
        )
        if self.reduction == 'mean':
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
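

# A minimal single-process sanity check of the LSE-combination identity described at the
# top of this file. This is only an illustrative sketch: it runs in plain PyTorch on one
# device, with no tensor parallelism and no xentropy_cuda_lib. It splits the vocab into
# `world_size` shards, combines the per-shard LSEs and losses the same way the forward
# pass above does, and compares against torch.nn.functional.cross_entropy.
if __name__ == "__main__":
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch, vocab_size, world_size = 4, 32, 4
    partition_vocab_size = vocab_size // world_size
    logits = torch.randn(batch, vocab_size)
    labels = torch.randint(0, vocab_size, (batch,))

    # Per-shard LSE and per-shard loss contribution (nonzero only on the owning shard).
    lse_shards = torch.empty(world_size, batch)
    losses = torch.zeros(batch)
    for r in range(world_size):
        start = r * partition_vocab_size
        shard = logits[:, start:start + partition_vocab_size]
        lse_shards[r] = torch.logsumexp(shard, dim=-1)
        owned = (labels >= start) & (labels < start + partition_vocab_size)
        # Local loss on the owning shard: lse_local - predicted_logit.
        losses[owned] += lse_shards[r][owned] - shard[owned, labels[owned] - start]

    # Combine: add the global LSE and subtract the owning shard's local LSE.
    lse_global = torch.logsumexp(lse_shards, dim=0)
    rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
    lse_local = lse_shards[rank_per_sample, torch.arange(batch)]
    losses += lse_global - lse_local

    assert torch.allclose(losses, F.cross_entropy(logits, labels, reduction='none'), atol=1e-6)
    print("LSE-combination sanity check passed")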