Browse Source

Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss

Tri Dao 2 years ago
parent
commit
dff68c2b22

+ 19 - 11
csrc/xentropy/interface.cpp

@@ -4,7 +4,8 @@
 std::vector<at::Tensor> softmax_xentropy_cuda(
     const at::Tensor &input,
     const at::Tensor &labels,
-    const float smoothing);
+    const float smoothing,
+    const int total_classes);
 
 at::Tensor softmax_xentropy_backward_cuda(
     const at::Tensor &grad_loss,
@@ -12,7 +13,8 @@ at::Tensor softmax_xentropy_backward_cuda(
     const at::Tensor &max_log_sum_exp,
     const at::Tensor &labels,
     const float smoothing,
-    const bool inplace);
+    const bool inplace,
+    const int total_classes);
 
 // C++ interface
 
@@ -23,11 +25,15 @@ at::Tensor softmax_xentropy_backward_cuda(
 std::vector<at::Tensor> softmax_xentropy_forward(
     const at::Tensor &input,
     const at::Tensor &labels,
-    const float smoothing) {
-    CHECK_CUDA(input);
+    const float smoothing,
+    const int total_classes=-1) {
+    // For tensor parallel cross entropy with smoothing, we want to pass in the total number
+    // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
+    // last dimension of the input tensor.
+    CHECK_INPUT(input);
     CHECK_INPUT(labels);
 
-    return softmax_xentropy_cuda(input, labels, smoothing);
+    return softmax_xentropy_cuda(input, labels, smoothing, total_classes);
 }
 
 at::Tensor softmax_xentropy_backward(
@@ -36,16 +42,18 @@ at::Tensor softmax_xentropy_backward(
     const at::Tensor &max_log_sum_exp,
     const at::Tensor &labels,
     const float smoothing,
-    const bool inplace)  {
-    CHECK_CUDA(grad_loss);
-    CHECK_CUDA(logits);
+    const bool inplace,
+    const int total_classes=-1)  {
+    CHECK_INPUT(grad_loss);
+    CHECK_INPUT(logits);
     CHECK_INPUT(max_log_sum_exp);
     CHECK_INPUT(labels);
 
-    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
+    return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels,
+                                          smoothing, inplace, total_classes);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)");
-    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)");
+    m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1);
+    m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1);
 }

+ 28 - 22
csrc/xentropy/xentropy_kernel.cu

@@ -434,7 +434,8 @@ cunn_SoftMaxXEntropyForward(
     scalar_t *input,
     int64_t *labels,
     int64_t classes,
-    const float smoothing)
+    const float smoothing,
+    const int total_classes)
 {
   extern __shared__ unsigned char smem[];
   auto sdata = reinterpret_cast<accscalar_t*>(smem);
@@ -472,12 +473,8 @@ cunn_SoftMaxXEntropyForward(
   // reserve max + log_sum_exp for bprop
   if (threadIdx.x == 0) {
     accscalar_t lse = max_k + std::log(sumAll);
-    if ((label >= 0) && (label < classes)) {
-      accscalar_t log_prob = epilogue(static_cast<accscalar_t>(input[label]));
-      losses[blockIdx.x] = (lse - sum_k / classes) * smoothing - log_prob * (1 - smoothing);
-    } else {
-      losses[blockIdx.x] = outscalar_t(0.f);
-    }
+    accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
+    losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing);
     max_log_sum_exp[blockIdx.x] = lse;
   }
 }
@@ -490,10 +487,11 @@ apply(scalar_t *gradInput,
       outscalar_t *gradOutput,
       int64_t *labels,
       const float smoothing,
-      int classes)
+      int classes,
+      const int total_classes)
 {
   accscalar_t smooth_positives = 1.0 - smoothing;
-  accscalar_t smooth_negatives = smoothing / classes;
+  accscalar_t smooth_negatives = smoothing / total_classes;
   accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
   int64_t label = labels[blockIdx.x];
   accscalar_t coeff = max_log_sum_exp[blockIdx.x];
@@ -534,10 +532,11 @@ aligned_apply(int shift,
               outscalar_t *gradOutput,
               int64_t *labels,
               const float smoothing,
-              int classes)
+              int classes,
+              const int total_classes)
 {
   accscalar_t smooth_positives = 1.0 - smoothing;
-  accscalar_t smooth_negatives = smoothing / classes;
+  accscalar_t smooth_negatives = smoothing / total_classes;
   accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
   int64_t label = labels[blockIdx.x];
   accscalar_t coeff = max_log_sum_exp[blockIdx.x];
@@ -602,7 +601,8 @@ cunn_SoftMaxXEntropyBackward(
     outscalar_t *gradOutput,
     int64_t *labels,
     const float smoothing,
-    int classes)
+    int classes,
+    const int total_classes)
 {
   gradInput += blockIdx.x * classes;
   logits += blockIdx.x * classes;
@@ -611,10 +611,10 @@ cunn_SoftMaxXEntropyBackward(
   const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
   const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
   if (shift == shift_){
-    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
+    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
   }
   else {
-    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
+    apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes);
   }
 
 }
@@ -623,7 +623,11 @@ template<template<typename, typename, typename> class Epilogue>
 std::vector<Tensor> host_softmax_xentropy(
         const Tensor & input_,
         const Tensor & labels_,
-        const float smoothing){
+        const float smoothing,
+        const int total_classes) {
+  // For tensor parallel cross entropy with smoothing, we want to pass in the total number
+  // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
+  // last dimension of the input tensor.
   AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");
 
   // Otherwise the kernel will be launched from cuda:0 device
@@ -666,7 +670,7 @@ std::vector<Tensor> host_softmax_xentropy(
       <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
         losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
         input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
-        dim_size, smoothing
+        dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
     );
   );
 
@@ -683,7 +687,8 @@ Tensor host_softmax_xentropy_backward(
     const at::Tensor &max_log_sum_exp,
     const at::Tensor &labels,
     const float smoothing,
-    bool inplace) {
+    bool inplace,
+    const int total_classes) {
   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
   at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};
@@ -730,7 +735,7 @@ Tensor host_softmax_xentropy_backward(
         gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
         max_log_sum_exp.data_ptr<accscalar_t>(),
         grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
-        smoothing, dim_size
+        smoothing, dim_size, total_classes
     );
   );
 
@@ -738,8 +743,8 @@ Tensor host_softmax_xentropy_backward(
   return gI;
 }
 
-std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing){
-  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing);
+std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){
+  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
 }
 
 at::Tensor softmax_xentropy_backward_cuda(
@@ -748,7 +753,8 @@ at::Tensor softmax_xentropy_backward_cuda(
     const at::Tensor &max_log_sum_exp,
     const at::Tensor &labels,
     const float smoothing,
-    const bool inplace) {
+    const bool inplace,
+    const int total_classes) {
   AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
-  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace);
+  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
 }

+ 128 - 0
flash_attn/losses/cross_entropy.py

@@ -0,0 +1,128 @@
+# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
+# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
+# the losses we can get the global loss. There's no need to do it step by step
+# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
+# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
+import torch
+import torch.nn as nn
+
+import xentropy_cuda_lib
+
+# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+# version of PyTorch. The following 2 lines are for backward compatibility with
+# older PyTorch.
+if "all_gather_into_tensor" not in dir(torch.distributed):
+    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+
+
+class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
+                process_group=None):
+        """
+        logits: (batch, vocab_size)
+        labels: (batch,)
+        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
+        one part of the vocab. The loss needs to be aggregated across processes.
+        """
+        batch, vocab_size = logits.shape
+        assert labels.shape == (batch,)
+        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
+        ctx.total_classes = world_size * vocab_size
+
+        if world_size == 1:
+            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
+            losses.masked_fill_(labels==ignored_index, 0)
+            labels_local = labels
+        else:
+            rank = torch.distributed.get_rank(process_group)
+            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
+
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
+            ignored_mask = labels == ignored_index
+            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
+
+            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
+            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
+            # last dimension of the input tensor.
+            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
+                                                          world_size * vocab_size)
+            assert lse_local.shape == (batch,)
+            assert losses.shape == (batch,)
+            losses.masked_fill_(ignored_mask, 0)
+            # For labels == ignored_index, the loss is always 0.
+            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
+            # lse_local - predicted logit, and 0 otherwise.
+            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
+            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
+            # For labels not in the vocab of this partition, losses contains
+            # 0.1 * (lse_local - sum logit / total_classes).
+
+            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
+                                        device=lse_local.device)
+            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
+                                                     group=process_group)
+            handle_losses = torch.distributed.all_reduce(
+                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
+            )
+            lse = torch.logsumexp(lse_allgather, dim=0)
+            # If there's no smoothing, the total losses are lse_local - predicted_logit,
+            # we just have to subtract the lse_local and add the lse (global).
+            # If there's smoothing=0.1, the total losses are
+            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
+            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
+            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
+            lse_local = lse_allgather[rank_per_sample,
+                                      torch.arange(batch, device=lse_allgather.device)]
+
+            handle_losses.wait()
+            if smoothing == 0.0:
+                losses += lse - lse_local
+            else:
+                losses += ((1 - smoothing) * (lse - lse_local)
+                           + smoothing * (lse - lse_allgather.sum(dim=0)))
+            losses.masked_fill_(ignored_mask, 0)
+
+        ctx.save_for_backward(logits, lse, labels_local)
+        ctx.smoothing = smoothing
+        ctx.ignored_index = ignored_index
+        ctx.inplace_backward = inplace_backward
+        return losses
+
+    @staticmethod
+    def backward(ctx, grad_loss):
+        logits, lse, labels = ctx.saved_tensors
+        grad_loss = grad_loss.contiguous()
+        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
+        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
+                                                 ctx.smoothing, ctx.inplace_backward,
+                                                 ctx.total_classes)
+        return grad_logits, None, None, None, None, None, None
+
+
+class CrossEntropyLoss(nn.Module):
+
+    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
+                 inplace_backward=False):
+        super().__init__()
+        if reduction not in ['mean', 'none']:
+            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.label_smoothing = label_smoothing
+        self.inplace_backward = inplace_backward
+
+    def forward(self, input, target, process_group=None):
+        assert input.is_cuda and target.is_cuda
+        # SoftmaxCrossEntropyLoss implicitly casts to float
+        loss = SoftmaxCrossEntropyLossFn.apply(
+            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
+            process_group
+        )
+        if self.reduction == 'mean':
+            return loss.sum() / (target != self.ignore_index).sum()
+        else:
+            return loss

+ 0 - 51
flash_attn/losses/cross_entropy_apex.py

@@ -1,51 +0,0 @@
-import torch
-import torch.nn as nn
-
-import xentropy_cuda_lib
-
-
-# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
-class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
-        losses, max_log_sum_exp = xentropy_cuda_lib.forward(
-            logits, labels, smoothing)
-        losses.masked_fill_(labels==padding_idx, 0)
-        ctx.save_for_backward(logits, max_log_sum_exp, labels)
-        ctx.smoothing = smoothing
-        ctx.padding_idx = padding_idx
-        ctx.inplace_backward = inplace_backward
-        return losses
-
-    @staticmethod
-    def backward(ctx, grad_loss):
-        logits, max_log_sum_exp, labels = ctx.saved_tensors
-        if not grad_loss.is_contiguous():
-            grad_loss = grad_loss.contiguous()
-        grad_loss.masked_fill_(labels==ctx.padding_idx, 0)
-        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
-                                                 ctx.smoothing, ctx.inplace_backward)
-        return grad_logits, None, None, None, None
-
-
-class CrossEntropyLossApex(nn.Module):
-
-    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
-        super().__init__()
-        if reduction not in ['mean', 'none']:
-            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
-        self.ignore_index = ignore_index
-        self.reduction = reduction
-        self.label_smoothing = label_smoothing
-        self.inplace_backward = inplace_backward
-
-    def forward(self, input, target):
-        assert input.is_cuda and target.is_cuda
-        # SoftmaxCrossEntropyLoss implicitly casts to float
-        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
-                                               self.ignore_index, self.inplace_backward)
-        if self.reduction == 'mean':
-            return loss.sum() / (target != self.ignore_index).sum()
-        else:
-            return loss

+ 0 - 122
flash_attn/losses/cross_entropy_parallel.py

@@ -1,122 +0,0 @@
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
-# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
-# the losses we can get the global loss. There's no need to do it step by step
-# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
-import torch
-import torch.nn as nn
-
-import xentropy_cuda_lib
-
-from apex.transformer.parallel_state import get_tensor_model_parallel_group
-from apex.transformer.parallel_state import get_tensor_model_parallel_rank
-from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
-from apex.transformer.tensor_parallel.utils import VocabUtility
-
-# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
-# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
-# version of PyTorch. The following 4 lines are for backward compatibility with
-# older PyTorch.
-if "all_gather_into_tensor" not in dir(torch.distributed):
-    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
-if "reduce_scatter_tensor" not in dir(torch.distributed):
-    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
-
-
-class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
-                inplace_backward=False):
-        """
-        logits_parallel: (batch, vocab_size / world_size)
-        labels: (batch,)
-        """
-        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
-        batch, partition_vocab_size = logits_parallel.shape
-        assert labels.shape == (batch,)
-        rank = get_tensor_model_parallel_rank()
-        world_size = get_tensor_model_parallel_world_size()
-
-        if world_size == 1:
-            losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
-            losses.masked_fill_(labels==ignored_index, 0)
-            labels_local = labels
-        else:
-            vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
-                partition_vocab_size, get_tensor_model_parallel_rank(),
-                get_tensor_model_parallel_world_size()
-            )
-
-            # Create a mask of valid vocab ids (1 means it needs to be masked).
-            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
-            ignored_mask = labels == ignored_index
-            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
-            masked_labels = labels_local.clone()
-            masked_labels[labels_mask] = ignored_index
-
-            losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
-            assert lse_local.shape == (batch,)
-            assert losses.shape == (batch,)
-            losses.masked_fill_(masked_labels==ignored_index, 0)
-
-            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
-                                        device=lse_local.device)
-            handle_lse = torch.distributed.all_gather_into_tensor(
-                lse_allgather, lse_local.contiguous(),
-                group=get_tensor_model_parallel_group(), async_op=True
-            )
-            handle_losses = torch.distributed.all_reduce(
-                losses, op=torch.distributed.ReduceOp.SUM,
-                group=get_tensor_model_parallel_group(), async_op=True
-            )
-            handle_lse.wait()
-            lse = torch.logsumexp(lse_allgather, dim=0)
-            # The losses are going to be lse_local - predicted_logit, we just have to subtract
-            # the lse_local and add the lse (global).
-            rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
-            lse_local = lse_allgather[rank_per_sample,
-                                      torch.arange(batch, device=lse_allgather.device)]
-
-            handle_losses.wait()
-            losses += lse - lse_local
-            losses.masked_fill_(ignored_mask, 0)
-
-        ctx.save_for_backward(logits_parallel, lse, labels_local)
-        ctx.smoothing = smoothing
-        ctx.ignored_index = ignored_index
-        ctx.inplace_backward = inplace_backward
-        return losses
-
-    @staticmethod
-    def backward(ctx, grad_loss):
-        logits_parallel, lse, labels = ctx.saved_tensors
-        if not grad_loss.is_contiguous():
-            grad_loss = grad_loss.contiguous()
-        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
-        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
-                                                 ctx.smoothing, ctx.inplace_backward)
-        return grad_logits, None, None, None, None, None
-
-
-class CrossEntropyLossParallel(nn.Module):
-
-    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
-        super().__init__()
-        if reduction not in ['mean', 'none']:
-            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
-        self.ignore_index = ignore_index
-        self.reduction = reduction
-        self.label_smoothing = label_smoothing
-        self.inplace_backward = inplace_backward
-
-    def forward(self, input, target):
-        assert input.is_cuda and target.is_cuda
-        # SoftmaxCrossEntropyLoss implicitly casts to float
-        loss = SoftmaxCrossEntropyLossParallelFn.apply(
-            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
-        )
-        if self.reduction == 'mean':
-            return loss.sum() / (target != self.ignore_index).sum()
-        else:
-            return loss

+ 4 - 4
flash_attn/models/bert.py

@@ -40,9 +40,9 @@ except ImportError:
     dropout_add_layer_norm, layer_norm = None, None
 
 try:
-    from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
 except ImportError:
-    CrossEntropyLossApex = None
+    CrossEntropyLoss = None
 
 
 logger = logging.getLogger(__name__)
@@ -374,10 +374,10 @@ class BertForPreTraining(BertPreTrainedModel):
         if self.last_layer_subset:
             assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
         use_xentropy = getattr(config, 'use_xentropy', False)
-        if use_xentropy and CrossEntropyLossApex is None:
+        if use_xentropy and CrossEntropyLoss is None:
             raise ImportError('xentropy_cuda is not installed')
         loss_cls = (nn.CrossEntropyLoss if not use_xentropy
-                    else partial(CrossEntropyLossApex, inplace_backward=True))
+                    else partial(CrossEntropyLoss, inplace_backward=True))
 
         self.bert = BertModel(config)
         self.cls = BertPreTrainingHeads(config)

+ 5 - 4
tests/losses/test_cross_entropy_apex.py → tests/losses/test_cross_entropy.py

@@ -6,7 +6,7 @@ import pytest
 
 from einops import rearrange
 
-from flass_attn.losses.cross_entropy_apex import CrossEntropyLossApex
+from flash_attn.losses.cross_entropy import CrossEntropyLossApex
 
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 
@@ -15,8 +15,9 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('inplace_backward', [False, True])
 # @pytest.mark.parametrize('inplace_backward', [False])
+@pytest.mark.parametrize('smoothing', [0.0, 0.9])
 @pytest.mark.parametrize('vocab_size', [50257])
-def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
+def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
     device = 'cuda'
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -27,8 +28,8 @@ def test_cross_entropy_loss_apex(vocab_size, inplace_backward, dtype):
     x = x_pt.detach().clone().requires_grad_()
     y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
     y[torch.randperm(batch_size * seqlen)[:10]] = -100
-    model_pt = torch.nn.CrossEntropyLoss()
-    model = CrossEntropyLossApex(inplace_backward=inplace_backward)
+    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
+    model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward)
     out = model(x, y)
     out_pt = model_pt(x_pt.float(), y)
     assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

+ 9 - 6
tests/losses/test_cross_entropy_parallel.py

@@ -10,19 +10,21 @@ import pytest
 from apex.transformer import parallel_state
 from apex.transformer import tensor_parallel
 
-from flash_attn.losses.cross_entropy_parallel import CrossEntropyLossParallel
+from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 
 
 @pytest.mark.parametrize('dtype', [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else []))
-# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+# @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('inplace_backward', [False, True])
 # @pytest.mark.parametrize('inplace_backward', [False])
+@pytest.mark.parametrize('smoothing', [0.0, 0.9])
+# @pytest.mark.parametrize('smoothing', [0.9])
 @pytest.mark.parametrize('vocab_size', [50264])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
-def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype):
+def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_backward, dtype):
     assert vocab_size % world_size == 0
     rtol, atol = ((1e-5, 1e-6) if dtype == torch.float32
                   else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3)))
@@ -42,9 +44,10 @@ def test_cross_entropy_loss_apex(vocab_size, world_size, inplace_backward, dtype
     x = tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt).detach().clone().requires_grad_()
     y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
     y[torch.randperm(batch_size * seqlen)[:10]] = -100
-    model_pt = torch.nn.CrossEntropyLoss(reduction='none')
-    model = CrossEntropyLossParallel(reduction='none', inplace_backward=inplace_backward)
-    out = model(x, y)
+    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction='none')
+    model = CrossEntropyLoss(label_smoothing=smoothing, reduction='none',
+                                     inplace_backward=inplace_backward)
+    out = model(x, y, process_group=parallel_state.get_tensor_model_parallel_group())
     out_pt = model_pt(x_pt.float(), y)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
 

+ 1 - 1
training/configs/experiment/owt/base.yaml

@@ -54,7 +54,7 @@ train:
   loss_fn:
     # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
     # It's also more numerically stable if we're using DeepSpeed 16 bits.
-    _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
+    _target_: src.losses.cross_entropy.CrossEntropyLoss
     inplace_backward: True  # to save memory
 
 eval:

+ 1 - 1
training/configs/experiment/pile/base.yaml

@@ -54,7 +54,7 @@ train:
   loss_fn:
     # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
     # It's also more numerically stable if we're using DeepSpeed 16 bits.
-    _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
+    _target_: src.losses.cross_entropy.CrossEntropyLoss
     inplace_backward: True  # to save memory
 
 eval:

+ 128 - 0
training/src/losses/cross_entropy.py

@@ -0,0 +1,128 @@
+# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
+# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
+# the losses we can get the global loss. There's no need to do it step by step
+# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
+# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
+import torch
+import torch.nn as nn
+
+import xentropy_cuda_lib
+
+# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+# version of PyTorch. The following 2 lines are for backward compatibility with
+# older PyTorch.
+if "all_gather_into_tensor" not in dir(torch.distributed):
+    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+
+
+class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
+                process_group=None):
+        """
+        logits: (batch, vocab_size)
+        labels: (batch,)
+        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
+        one part of the vocab. The loss needs to be aggregated across processes.
+        """
+        batch, vocab_size = logits.shape
+        assert labels.shape == (batch,)
+        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
+        ctx.total_classes = world_size * vocab_size
+
+        if world_size == 1:
+            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
+            losses.masked_fill_(labels==ignored_index, 0)
+            labels_local = labels
+        else:
+            rank = torch.distributed.get_rank(process_group)
+            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
+
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
+            ignored_mask = labels == ignored_index
+            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
+
+            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
+            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
+            # last dimension of the input tensor.
+            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
+                                                          world_size * vocab_size)
+            assert lse_local.shape == (batch,)
+            assert losses.shape == (batch,)
+            losses.masked_fill_(ignored_mask, 0)
+            # For labels == ignored_index, the loss is always 0.
+            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
+            # lse_local - predicted logit, and 0 otherwise.
+            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
+            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
+            # For labels not in the vocab of this partition, losses contains
+            # 0.1 * (lse_local - sum logit / total_classes).
+
+            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
+                                        device=lse_local.device)
+            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
+                                                     group=process_group)
+            handle_losses = torch.distributed.all_reduce(
+                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
+            )
+            lse = torch.logsumexp(lse_allgather, dim=0)
+            # If there's no smoothing, the total losses are lse_local - predicted_logit,
+            # we just have to subtract the lse_local and add the lse (global).
+            # If there's smoothing=0.1, the total losses are
+            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
+            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
+            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
+            lse_local = lse_allgather[rank_per_sample,
+                                      torch.arange(batch, device=lse_allgather.device)]
+
+            handle_losses.wait()
+            if smoothing == 0.0:
+                losses += lse - lse_local
+            else:
+                losses += ((1 - smoothing) * (lse - lse_local)
+                           + smoothing * (lse - lse_allgather.sum(dim=0)))
+            losses.masked_fill_(ignored_mask, 0)
+
+        ctx.save_for_backward(logits, lse, labels_local)
+        ctx.smoothing = smoothing
+        ctx.ignored_index = ignored_index
+        ctx.inplace_backward = inplace_backward
+        return losses
+
+    @staticmethod
+    def backward(ctx, grad_loss):
+        logits, lse, labels = ctx.saved_tensors
+        grad_loss = grad_loss.contiguous()
+        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
+        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
+                                                 ctx.smoothing, ctx.inplace_backward,
+                                                 ctx.total_classes)
+        return grad_logits, None, None, None, None, None, None
+
+
+class CrossEntropyLoss(nn.Module):
+
+    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
+                 inplace_backward=False):
+        super().__init__()
+        if reduction not in ['mean', 'none']:
+            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.label_smoothing = label_smoothing
+        self.inplace_backward = inplace_backward
+
+    def forward(self, input, target, process_group=None):
+        assert input.is_cuda and target.is_cuda
+        # SoftmaxCrossEntropyLoss implicitly casts to float
+        loss = SoftmaxCrossEntropyLossFn.apply(
+            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
+            process_group
+        )
+        if self.reduction == 'mean':
+            return loss.sum() / (target != self.ignore_index).sum()
+        else:
+            return loss

+ 0 - 51
training/src/losses/cross_entropy_apex.py

@@ -1,51 +0,0 @@
-import torch
-import torch.nn as nn
-
-import xentropy_cuda_lib
-
-
-# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
-class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
-        losses, max_log_sum_exp = xentropy_cuda_lib.forward(
-            logits, labels, smoothing)
-        losses.masked_fill_(labels==padding_idx, 0)
-        ctx.save_for_backward(logits, max_log_sum_exp, labels)
-        ctx.smoothing = smoothing
-        ctx.padding_idx = padding_idx
-        ctx.inplace_backward = inplace_backward
-        return losses
-
-    @staticmethod
-    def backward(ctx, grad_loss):
-        logits, max_log_sum_exp, labels = ctx.saved_tensors
-        if not grad_loss.is_contiguous():
-            grad_loss = grad_loss.contiguous()
-        grad_loss.masked_fill_(labels==ctx.padding_idx, 0)
-        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
-                                                 ctx.smoothing, ctx.inplace_backward)
-        return grad_logits, None, None, None, None
-
-
-class CrossEntropyLossApex(nn.Module):
-
-    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
-        super().__init__()
-        if reduction not in ['mean', 'none']:
-            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
-        self.ignore_index = ignore_index
-        self.reduction = reduction
-        self.label_smoothing = label_smoothing
-        self.inplace_backward = inplace_backward
-
-    def forward(self, input, target):
-        assert input.is_cuda and target.is_cuda
-        # SoftmaxCrossEntropyLoss implicitly casts to float
-        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
-                                               self.ignore_index, self.inplace_backward)
-        if self.reduction == 'mean':
-            return loss.sum() / (target != self.ignore_index).sum()
-        else:
-            return loss

+ 0 - 112
training/src/losses/cross_entropy_parallel.py

@@ -1,112 +0,0 @@
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
-# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
-# the losses we can get the global loss. There's no need to do it step by step
-# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
-import torch
-import torch.nn as nn
-
-import xentropy_cuda_lib
-
-from apex.transformer.parallel_state import get_tensor_model_parallel_group
-from apex.transformer.parallel_state import get_tensor_model_parallel_rank
-from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
-from apex.transformer.tensor_parallel.utils import VocabUtility
-
-# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
-# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
-# version of PyTorch. The following 4 lines are for backward comparability with
-# older PyTorch.
-if "all_gather_into_tensor" not in dir(torch.distributed):
-    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
-if "reduce_scatter_tensor" not in dir(torch.distributed):
-    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
-
-
-class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
-                inplace_backward=False):
-        """
-        logits_parallel: (batch, vocab_size / world_size)
-        labels: (batch,)
-        """
-        assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
-        batch, partition_vocab_size = logits_parallel.shape
-        assert labels.shape == (batch,)
-        rank = get_tensor_model_parallel_rank()
-        world_size = get_tensor_model_parallel_world_size()
-        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
-            partition_vocab_size, get_tensor_model_parallel_rank(),
-            get_tensor_model_parallel_world_size()
-        )
-
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
-        ignored_mask = labels == ignored_index
-        labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
-        masked_labels = labels_local.clone()
-        masked_labels[labels_mask] = ignored_index
-
-        losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
-        assert lse_local.shape == (batch,)
-        assert losses.shape == (batch,)
-        losses.masked_fill_(masked_labels==ignored_index, 0)
-
-        if world_size > 1:
-            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
-                                        device=lse_local.device)
-            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
-                                                     group=get_tensor_model_parallel_group())
-            lse = torch.logsumexp(lse_allgather, dim=0)
-            torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
-                                         group=get_tensor_model_parallel_group())
-            # The losses are currently lse_local - predicted_logit, we just have to subtract the
-            # lse_local and add the lse (global).
-            rank_per_sample = labels // partition_vocab_size
-            lse_local = lse_allgather[rank_per_sample,
-                                      torch.arange(batch, device=lse_allgather.device)]
-            losses += lse - lse_local
-            losses.masked_fill_(ignored_mask, 0)
-        else:
-            lse = lse_local
-
-        ctx.save_for_backward(logits_parallel, lse, labels_local)
-        ctx.smoothing = smoothing
-        ctx.ignored_index = ignored_index
-        ctx.inplace_backward = inplace_backward
-        return losses
-
-    @staticmethod
-    def backward(ctx, grad_loss):
-        logits_parallel, lse, labels = ctx.saved_tensors
-        if not grad_loss.is_contiguous():
-            grad_loss = grad_loss.contiguous()
-        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
-        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
-                                                 ctx.smoothing, ctx.inplace_backward)
-        return grad_logits, None, None, None, None, None
-
-
-class CrossEntropyLossParallel(nn.Module):
-
-    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
-        super().__init__()
-        if reduction not in ['mean', 'none']:
-            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
-        self.ignore_index = ignore_index
-        self.reduction = reduction
-        self.label_smoothing = label_smoothing
-        self.inplace_backward = inplace_backward
-
-    def forward(self, input, target):
-        assert input.is_cuda and target.is_cuda
-        # SoftmaxCrossEntropyLoss implicitly casts to float
-        loss = SoftmaxCrossEntropyLossParallelFn.apply(
-            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
-        )
-        if self.reduction == 'mean':
-            return loss.sum() / (target != self.ignore_index).sum()
-        else:
-            return loss

+ 1 - 1
training/src/metrics/perplexity.py

@@ -11,7 +11,7 @@ from torch import Tensor
 from torchmetrics import Metric
 
 try:
-    from src.losses.cross_entropy_apex import CrossEntropyLossApex as CrossEntropyLoss
+    from src.losses.cross_entropy import CrossEntropyLoss
 except ImportError:
     CrossEntropyLoss = torch.nn.CrossEntropyLoss