cross_entropy.py

# Copyright (c) 2024, Tri Dao.

import torch
import torch.nn as nn

from flash_attn.ops.triton.cross_entropy import cross_entropy_loss


class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        inplace_backward=False,
        process_group=None,
        return_z_loss=False,
    ):
  17. """
  18. Arguments:
  19. ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
  20. label_smoothing: float
  21. lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
  22. This is also referred to as "z-loss".
  23. inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
  24. This saves memory.
  25. process_group: if not None, we're doing Tensor Parallel: each process is responsible for
  26. one part of the vocab. The loss will be aggregated across processes.
  27. return_z_loss: bool. If True, we return the component of the loss contributed by
  28. the lse_square_scale value. This value is only for logging and does not support
  29. backprop.
  30. """
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss

    def forward(self, input, target, precomputed_lse=None):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            precomputed_lse=precomputed_lse,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        # Reduce over the batch, counting only positions whose target is not ignore_index.
        if self.reduction == "mean":
            loss = loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            loss = loss.sum()
        if not self.return_z_loss:
            return loss
        # Apply the same reduction to the z-loss component (logging only, no backprop).
        if self.reduction == "mean":
            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            z_loss = z_loss.sum()
        return loss, z_loss
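

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of how this module could stand in for
# torch.nn.CrossEntropyLoss. It assumes a CUDA device and a working flash-attn
# install; the batch size, vocab size, and lse_square_scale value below are
# arbitrary illustrative choices, not recommendations.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, vocab_size = 8, 32000
    logits = torch.randn(batch, vocab_size, device="cuda", requires_grad=True)
    labels = torch.randint(0, vocab_size, (batch,), device="cuda")
    labels[0] = -100  # ignored position: contributes 0 and is excluded from the mean

    criterion = CrossEntropyLoss(
        ignore_index=-100,
        reduction="mean",
        lse_square_scale=1e-4,  # adds 1e-4 * logsumexp(logits)^2 per token ("z-loss")
        return_z_loss=True,
    )
    loss, z_loss = criterion(logits, labels)
    loss.backward()
    print(f"loss={loss.item():.4f}  z_loss={z_loss.item():.4f} (z_loss is for logging only)")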