cross_entropy.py

# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn

from flash_attn.ops.triton.cross_entropy import cross_entropy_loss


class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        inplace_backward=False,
        process_group=None,
        return_z_loss=False,
    ):
        """
        Arguments:
            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
            label_smoothing: float
            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                This is also referred to as "z-loss".
            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                This saves memory.
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                one part of the vocab. The loss will be aggregated across processes.
            return_z_loss: bool. If True, we return the component of the loss contributed by
                the lse_square_scale value. This value is only for logging and does not support
                backprop.
        """
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss
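
    # Note (added for clarity; not in the original file): with lse_square_scale = z > 0,
    # the kernel adds roughly z * logsumexp(logits) ** 2 to each token's cross-entropy
    # loss. This "z-loss" term (used e.g. in PaLM) discourages the softmax normalizer
    # from drifting away from 1; `return_z_loss` exposes that component for logging only.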
    def forward(self, input, target):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        if self.reduction == "mean":
            loss = loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            loss = loss.sum()
        else:
            loss = loss
        if not self.return_z_loss:
            return loss
        if self.reduction == "mean":
            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            z_loss = z_loss.sum()
        else:
            z_loss = z_loss
        return loss, z_loss
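

# --- Usage sketch (added; not part of the original file) ---
# A minimal example of how this module could be called. It assumes flash-attn is
# installed with its Triton kernels and that a CUDA device is available; the shapes
# and hyperparameter values below are illustrative, not taken from the source.
if __name__ == "__main__":
    batch, vocab_size = 4, 32000
    logits = torch.randn(
        batch, vocab_size, device="cuda", dtype=torch.float16, requires_grad=True
    )
    labels = torch.randint(0, vocab_size, (batch,), device="cuda")

    loss_fn = CrossEntropyLoss(
        label_smoothing=0.1,      # smooth the target distribution slightly
        lse_square_scale=1e-4,    # add the z-loss regularizer
        inplace_backward=True,    # reuse the logits buffer in backward to save memory
        return_z_loss=True,       # also return the z-loss component for logging
    )
    loss, z_loss = loss_fn(logits, labels)  # both scalars under the default "mean" reduction
    loss.backward()  # z_loss is logging-only and is not backpropagated
    print(f"loss={loss.item():.4f}  z_loss={z_loss.item():.4f}")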