# cross_entropy_apex.py

import torch
import torch.nn as nn

import xentropy_cuda_lib

# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
        # The fused kernel returns per-token losses plus the max-log-sum-exp
        # statistics needed to reconstruct the softmax in backward.
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing)
        # Zero out the loss at padding positions so they don't contribute.
        losses.masked_fill_(labels == padding_idx, 0)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.padding_idx = padding_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        # Padding positions receive zero gradient.
        grad_loss.masked_fill_(labels == ctx.padding_idx, 0)
        # With inplace_backward=True the kernel overwrites the saved logits with
        # their gradient, saving one logits-sized allocation.
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
                                                 ctx.smoothing, ctx.inplace_backward)
        # One gradient slot per forward input (logits, labels, smoothing,
        # padding_idx, inplace_backward); only logits is differentiable.
        return grad_logits, None, None, None, None
class CrossEntropyLossApex(nn.Module):

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            # Average only over non-ignored targets, matching nn.CrossEntropyLoss.
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
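

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal sketch comparing the fused loss against nn.CrossEntropyLoss,
# assuming a CUDA device is available and the apex xentropy extension
# (xentropy_cuda_lib) has been built. Shapes and dtype are arbitrary
# example choices.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, vocab_size = 8, 1000
    logits = torch.randn(batch, vocab_size, device='cuda',
                         dtype=torch.float16, requires_grad=True)
    target = torch.randint(0, vocab_size, (batch,), device='cuda')

    loss_fn = CrossEntropyLossApex(ignore_index=-100, reduction='mean')
    loss = loss_fn(logits, target)
    loss.backward()

    # Reference computed in float32; the fused kernel also accumulates in float,
    # so the two values should agree closely.
    ref = nn.CrossEntropyLoss(ignore_index=-100)(logits.float(), target)
    print(f'fused loss {loss.item():.6f}  reference {ref.item():.6f}')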