# cross_entropy.py
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.).
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
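#
# Concretely, if the label y lives in the vocab shard of rank r, then without smoothing
#     loss = lse_global - logit[y] = (lse_local_r - logit[y]) + (lse_global - lse_local_r),
# where lse_global = logsumexp over the per-rank LSEs. So each rank only needs to all-gather one
# scalar per sample (its local LSE) and all-reduce its local loss; the forward pass below applies
# this correction (plus the analogous one for label smoothing).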
import torch
import torch.nn as nn

import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`; they require a recent version of PyTorch.
# The following 2 lines are for backward compatibility with older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
                process_group=None):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
        one part of the vocab. The loss needs to be aggregated across processes.
        """
        batch, vocab_size = logits.shape
        assert labels.shape == (batch,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        ctx.total_classes = world_size * vocab_size

        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size

            # Mask of labels that fall outside this partition's vocab range (True means the label
            # belongs to another partition).
            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
            ignored_mask = labels == ignored_index
            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)

            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
                                                          world_size * vocab_size)
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
            # For labels == ignored_index, the loss is always 0.
            # If there's no smoothing: for labels in the vocab of this partition, losses contains
            # lse_local - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1: for labels in the vocab of this partition, losses contains
            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes).
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).
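
            # Gather every rank's LSE and start summing the partial losses across ranks
            # asynchronously, so that the loss all-reduce overlaps with the logsumexp below.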
            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
                                        device=lse_local.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
                                                     group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            # If there's no smoothing, the total losses are lse_local - predicted_logit,
            # we just have to subtract the lse_local and add the lse (global).
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes).
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
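            # The difference (want - have) is therefore
            #     0.9 * (lse - lse_local) + 0.1 * (lse - sum of all lse_local),
            # which is exactly what gets added to the all-reduced losses below.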
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
            lse_local = lse_allgather[rank_per_sample,
                                      torch.arange(batch, device=lse_allgather.device)]

            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += ((1 - smoothing) * (lse - lse_local)
                           + smoothing * (lse - lse_allgather.sum(dim=0)))
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
        ctx.smoothing = smoothing
        ctx.ignored_index = ignored_index
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
                                                 ctx.smoothing, ctx.inplace_backward,
                                                 ctx.total_classes)
        return grad_logits, None, None, None, None, None, None


class CrossEntropyLoss(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target, process_group=None):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLossFn implicitly casts the logits to float
        loss = SoftmaxCrossEntropyLossFn.apply(
            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
            process_group
        )
        if self.reduction == 'mean':
            # Average only over targets that are not ignored.
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
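

# A minimal single-GPU usage sketch, not part of the module proper: it assumes a CUDA device and
# the xentropy_cuda_lib extension are available, and the shapes/values below are illustrative only.
# With process_group=None the fused loss should match PyTorch's cross entropy up to numerical error.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, vocab_size = 8, 512
    logits = torch.randn(batch, vocab_size, device="cuda", dtype=torch.float16)
    labels = torch.randint(0, vocab_size, (batch,), device="cuda")
    labels[0] = -100  # exercise the ignore_index path

    loss_fused = CrossEntropyLoss(reduction="mean", label_smoothing=0.1)(logits, labels)
    loss_ref = nn.functional.cross_entropy(
        logits.float(), labels, ignore_index=-100, label_smoothing=0.1, reduction="mean"
    )
    print(f"fused: {loss_fused.item():.6f}  reference: {loss_ref.item():.6f}")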