Browse Source

[CE] Implement CrossEntropyLoss in Triton

Tri Dao 1 year ago
parent
commit
5400fdc4ac

+ 5 - 0
csrc/xentropy/README.md

@@ -7,3 +7,8 @@ It has only been tested on A100s.
 ```sh
 cd csrc/xentropy && pip install .
 ```
+
+As of 2023-09-15, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py). 
+See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details.

+ 34 - 119
flash_attn/losses/cross_entropy.py

@@ -1,116 +1,9 @@
-# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
-# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
-# the losses we can get the global loss. There's no need to do it step by step
-# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
-# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
+# Copyright (c) 2023, Tri Dao.
+
 import torch
 import torch.nn as nn
-import xentropy_cuda_lib
-
-# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
-# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
-# version of PyTorch. The following 2 lines are for backward compatibility with
-# older PyTorch.
-if "all_gather_into_tensor" not in dir(torch.distributed):
-    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
-
-
-class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        logits,
-        labels,
-        smoothing=0.0,
-        ignored_index=-100,
-        inplace_backward=False,
-        process_group=None,
-    ):
-        """
-        logits: (batch, vocab_size)
-        labels: (batch,)
-        If process_group is not None, we're doing Tensor Parallel: each process is responsible for
-        one part of the vocab. The loss needs to be aggregated across processes.
-        """
-        batch, vocab_size = logits.shape
-        assert labels.shape == (batch,)
-        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
-        ctx.total_classes = world_size * vocab_size
-
-        if world_size == 1:
-            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
-            losses.masked_fill_(labels == ignored_index, 0)
-            labels_local = labels
-        else:
-            rank = torch.distributed.get_rank(process_group)
-            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
 
-            # Create a mask of valid vocab ids (1 means it needs to be masked).
-            labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
-            ignored_mask = labels == ignored_index
-            labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
-
-            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
-            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
-            # last dimension of the input tensor.
-            losses, lse_local = xentropy_cuda_lib.forward(
-                logits, labels_local, smoothing, world_size * vocab_size
-            )
-            assert lse_local.shape == (batch,)
-            assert losses.shape == (batch,)
-            losses.masked_fill_(ignored_mask, 0)
-            # For labels == ignored_index, the loss is always 0.
-            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
-            # lse_local - predicted logit, and 0 otherwise.
-            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
-            # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
-            # For labels not in the vocab of this partition, losses contains
-            # 0.1 * (lse_local - sum logit / total_classes).
-
-            lse_allgather = torch.empty(
-                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
-            )
-            torch.distributed.all_gather_into_tensor(
-                lse_allgather, lse_local.contiguous(), group=process_group
-            )
-            handle_losses = torch.distributed.all_reduce(
-                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
-            )
-            lse = torch.logsumexp(lse_allgather, dim=0)
-            # If there's no smoothing, the total losses are lse_local - predicted_logit,
-            # we just have to subtract the lse_local and add the lse (global).
-            # If there's smoothing=0.1, the total losses are
-            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
-            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
-            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
-            lse_local = lse_allgather[
-                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
-            ]
-
-            handle_losses.wait()
-            if smoothing == 0.0:
-                losses += lse - lse_local
-            else:
-                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
-                    lse - lse_allgather.sum(dim=0)
-                )
-            losses.masked_fill_(ignored_mask, 0)
-
-        ctx.save_for_backward(logits, lse, labels_local)
-        ctx.smoothing = smoothing
-        ctx.ignored_index = ignored_index
-        ctx.inplace_backward = inplace_backward
-        return losses
-
-    @staticmethod
-    def backward(ctx, grad_loss):
-        logits, lse, labels = ctx.saved_tensors
-        grad_loss = grad_loss.contiguous()
-        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
-        grad_logits = xentropy_cuda_lib.backward(
-            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
-        )
-        return grad_logits, None, None, None, None, None, None
+from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
 
 
 class CrossEntropyLoss(nn.Module):
@@ -119,30 +12,52 @@ class CrossEntropyLoss(nn.Module):
         ignore_index=-100,
         reduction="mean",
         label_smoothing=0.0,
+        lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
     ):
+        """
+        Arguments:
+            ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+            label_smoothing: float
+            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
+                This is also referred to as "z-loss".
+            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
+                This saves memory.
+            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+            one part of the vocab. The loss will be aggregated across processes.
+        """
         super().__init__()
-        if reduction not in ["mean", "none"]:
-            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
+        if reduction not in ["mean", "none", "sum"]:
+            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
         self.ignore_index = ignore_index
         self.reduction = reduction
         self.label_smoothing = label_smoothing
+        self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
 
     def forward(self, input, target):
-        assert input.is_cuda and target.is_cuda
-        # SoftmaxCrossEntropyLoss implicitly casts to float
-        loss = SoftmaxCrossEntropyLossFn.apply(
+        """
+        Arguments:
+            input: (batch, vocab_size)
+            target: (batch,)
+        Returns:
+            losses: (batch,) if reduction is 'none', else (1,), dtype float
+        """
+        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
+        loss = cross_entropy_loss(
             input,
             target,
-            self.label_smoothing,
-            self.ignore_index,
-            self.inplace_backward,
-            self.process_group,
+            label_smoothing=self.label_smoothing,
+            lse_square_scale=self.lse_square_scale,
+            ignored_index=self.ignore_index,
+            inplace_backward=self.inplace_backward,
+            process_group=self.process_group,
         )
         if self.reduction == "mean":
             return loss.sum() / (target != self.ignore_index).sum()
+        elif self.reduction == "sum":
+            return loss.sum()
         else:
             return loss

+ 293 - 0
flash_attn/ops/triton/cross_entropy.py

@@ -0,0 +1,293 @@
+# Copyright (c) 2023, Tri Dao.
+
+from typing import Tuple, Optional, Union
+
+import torch
+
+from einops import rearrange
+
+import triton
+import triton.language as tl
+
+# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+# version of PyTorch. The following 2 lines are for backward compatibility with
+# older PyTorch.
+if "all_gather_into_tensor" not in dir(torch.distributed):
+    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+
+
+@triton.heuristics(
+    {
+        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
+    }
+)
+@triton.jit
+def cross_entropy_fwd_kernel(
+    loss_ptr,  # data ptrs
+    lse_ptr,
+    logits_ptr,
+    labels_ptr,
+    smoothing,
+    lse_square_scale,
+    ignored_index,
+    total_classes,
+    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
+    n_cols,  # shapes
+    n_rows,
+    logits_row_stride,  # strides
+    BLOCK_SIZE: tl.constexpr,
+    HAS_SMOOTHING: tl.constexpr,
+    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
+    SPLIT: tl.constexpr,
+):
+    row_idx = tl.program_id(0)
+    col_block_idx = tl.program_id(1)
+    logits_ptr = logits_ptr + row_idx * logits_row_stride
+    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    label_idx = tl.load(labels_ptr + row_idx)
+    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
+        tl.float32
+    )
+    max_logits = tl.max(logits, 0)
+    if HAS_SMOOTHING:
+        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
+    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
+    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
+    if label_idx == ignored_index:
+        loss = 0.0
+    else:
+        label_idx -= class_start_idx
+        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
+            n_cols, (col_block_idx + 1) * BLOCK_SIZE
+        ):
+            logits_label = tl.load(logits_ptr + label_idx)
+            if HAS_SMOOTHING:
+                loss = (
+                    (lse if not SPLIT else 0.0)
+                    - smoothing * sum_logits / total_classes
+                    - (1 - smoothing) * logits_label
+                )
+            else:
+                loss = (lse if not SPLIT else 0.0) - logits_label
+        else:
+            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
+            if HAS_SMOOTHING:
+                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
+            else:
+                loss = 0.0
+        if not SPLIT:
+            loss += lse_square_scale * lse * lse
+    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
+
+
+@triton.heuristics(
+    {
+        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
+    }
+)
+@triton.jit
+def cross_entropy_bwd_kernel(
+    dlogits_ptr,  # data ptrs
+    dloss_ptr,
+    logits_ptr,
+    lse_ptr,
+    labels_ptr,
+    smoothing,
+    lse_square_scale,
+    ignored_index,
+    total_classes,
+    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
+    n_cols,  # shapes
+    logits_row_stride,  # strides
+    dlogits_row_stride,
+    dloss_row_stride,
+    BLOCK_SIZE: tl.constexpr,
+    HAS_SMOOTHING: tl.constexpr,
+):
+    row_idx = tl.program_id(0)
+    col_block_idx = tl.program_id(1)
+    logits_ptr = logits_ptr + row_idx * logits_row_stride
+    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride
+    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    label_idx = tl.load(labels_ptr + row_idx)
+    if label_idx != ignored_index:
+        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
+    else:
+        dloss = 0.0
+    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
+        tl.float32
+    )
+    lse = tl.load(lse_ptr + row_idx)
+    probs = tl.exp(logits - lse)
+    probs += 2.0 * lse_square_scale * lse * probs
+    label_idx -= class_start_idx
+    if HAS_SMOOTHING:
+        smooth_positive = 1.0 - smoothing
+        smooth_negative = smoothing / total_classes
+        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
+    else:
+        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
+    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
+
+
+class CrossEntropyLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        logits,
+        labels,
+        smoothing,
+        lse_square_scale=0.0,
+        ignored_index=-100,
+        inplace_backward=False,
+        process_group=None,
+    ):
+        n_rows, n_cols = logits.shape
+        assert labels.shape == (n_rows,)
+        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
+        total_classes = world_size * n_cols
+        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
+        class_start_idx = rank * n_cols
+
+        if logits.stride(-1) != 1:
+            logits = logits.contiguous()
+        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
+        MAX_BLOCK_SIZE = 64 * 1024
+        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
+        num_warps = (
+            4
+            if BLOCK_SIZE < 2048
+            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
+        )
+        # We may split the lse computation across multiple blocks, then do a reduction
+        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
+        # where having just one thread block processing more than 64k elements is slow.
+        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
+        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
+        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
+        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
+        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
+        # Need this, otherwise Triton tries to launch from cuda:0 and we get
+        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+        with torch.cuda.device(logits.device.index):
+            cross_entropy_fwd_kernel[(n_rows, n_splits)](
+                losses,  # data ptrs
+                lse,
+                logits,
+                labels,
+                smoothing,
+                lse_square_scale,
+                ignored_index,
+                total_classes,
+                class_start_idx,
+                n_cols,  # shapes
+                n_rows,
+                logits.stride(0),  # strides
+                BLOCK_SIZE=BLOCK_SIZE,  # constants
+                num_warps=num_warps,
+                SPLIT=split,
+            )
+
+        if split:
+            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
+            # - predicted logit, and 0 otherwise.
+            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
+            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
+            # For labels not in the vocab of this partition, losses contains
+            # -0.1 * sum logit / total_classes.
+            if world_size > 1:
+                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
+                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
+                handle_losses = torch.distributed.all_reduce(
+                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
+                )
+                lse = torch.logsumexp(lse_allgather, dim=0)
+                handle_losses.wait()
+            else:
+                lse = torch.logsumexp(lse, dim=0)
+                losses = losses.sum(dim=0)
+            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
+            # we just have to add the (global) lse.
+            # If there's smoothing=0.1, the total losses are
+            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
+            # Again, we just have to add the (global) lse.
+            losses += lse
+            if lse_square_scale != 0.0:
+                losses += lse_square_scale * lse.square()
+            losses.masked_fill_(labels == ignored_index, 0.0)
+
+        ctx.save_for_backward(logits, lse, labels)
+        ctx.smoothing = smoothing
+        ctx.lse_square_scale = lse_square_scale
+        ctx.ignored_index = ignored_index
+        ctx.total_classes = total_classes
+        ctx.class_start_idx = class_start_idx
+        ctx.inplace_backward = inplace_backward
+        return losses
+
+    @staticmethod
+    def backward(ctx, grad_losses):
+        logits, lse, labels = ctx.saved_tensors
+        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
+        n_rows, n_cols = logits.shape
+        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
+        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
+        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
+        # Need this, otherwise Triton tries to launch from cuda:0 and we get
+        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+        with torch.cuda.device(logits.device.index):
+            cross_entropy_bwd_kernel[grid](
+                dlogits,  # data ptrs
+                grad_losses,
+                logits,
+                lse,
+                labels,
+                ctx.smoothing,
+                ctx.lse_square_scale,
+                ctx.ignored_index,
+                ctx.total_classes,
+                ctx.class_start_idx,
+                n_cols,  # shapes
+                logits.stride(0),  # strides
+                dlogits.stride(0),
+                grad_losses.stride(0),
+                BLOCK_SIZE=BLOCK_SIZE,  # constants
+                num_warps=num_warps,
+            )
+        return dlogits, None, None, None, None, None, None, None
+
+
+def cross_entropy_loss(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    label_smoothing: float = 0.0,
+    lse_square_scale: float = 0.0,
+    ignored_index=-100,
+    inplace_backward: bool = False,
+    process_group=None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        logits: (batch, vocab_size)
+        labels: (batch,)
+        label_smoothing: float
+        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
+            This is also referred to as "z-loss".
+        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
+            This saves memory.
+        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+        one part of the vocab. The loss will be aggregated across processes.
+    Returns:
+        losses: (batch,), float
+    """
+    return CrossEntropyLoss.apply(
+        logits,
+        labels,
+        label_smoothing,
+        lse_square_scale,
+        ignored_index,
+        inplace_backward,
+        process_group,
+    )

+ 20 - 8
tests/losses/test_cross_entropy.py

@@ -4,7 +4,7 @@ import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn.losses.cross_entropy import CrossEntropyLossApex
+from flash_attn.losses.cross_entropy import CrossEntropyLoss
 
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 
@@ -12,12 +12,16 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize(
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
-# @pytest.mark.parametrize('dtype', [torch.float16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("inplace_backward", [False, True])
-# @pytest.mark.parametrize('inplace_backward', [False])
+# @pytest.mark.parametrize("inplace_backward", [False])
+@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+# @pytest.mark.parametrize("lse_square_scale", [1e-2])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
-@pytest.mark.parametrize("vocab_size", [50257])
-def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype):
+# @pytest.mark.parametrize("smoothing", [0.0])
+@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
+# @pytest.mark.parametrize("vocab_size", [12])
+def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -29,12 +33,20 @@ def test_cross_entropy_loss_apex(vocab_size, smoothing, inplace_backward, dtype)
     )
     x = x_pt.detach().clone().requires_grad_()
     y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
-    y[torch.randperm(batch_size * seqlen)[:10]] = -100
+    if batch_size * seqlen > 10:
+        y[torch.randperm(batch_size * seqlen)[:10]] = -100
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
-    model = CrossEntropyLossApex(label_smoothing=smoothing, inplace_backward=inplace_backward)
+    model = CrossEntropyLoss(
+        label_smoothing=smoothing,
+        lse_square_scale=lse_square_scale,
+        inplace_backward=inplace_backward,
+    )
     out = model(x, y)
     out_pt = model_pt(x_pt.float(), y)
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
+    if lse_square_scale > 0.0:
+        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
+    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
 
     g = torch.randn_like(out)
     out_pt.backward(g)

+ 18 - 8
tests/losses/test_cross_entropy_parallel.py

@@ -1,5 +1,5 @@
 # Run test with:
-# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/losses/test_cross_entropy_parallel.py
+# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/losses/test_cross_entropy_parallel.py
 
 import math
 
@@ -15,15 +15,20 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 @pytest.mark.parametrize(
     "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
 )
-# @pytest.mark.parametrize('dtype', [torch.float16])
+# @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("inplace_backward", [False, True])
-# @pytest.mark.parametrize('inplace_backward', [False])
+# @pytest.mark.parametrize("inplace_backward", [False])
+@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
+# @pytest.mark.parametrize("lse_square_scale", [1e-2])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
-# @pytest.mark.parametrize('smoothing', [0.9])
-@pytest.mark.parametrize("vocab_size", [50264])
-@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
-# @pytest.mark.parametrize('world_size', [2])
-def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_backward, dtype):
+# @pytest.mark.parametrize("smoothing", [0.0])
+@pytest.mark.parametrize("vocab_size", [50264, 128 * 1024])  # test vocab larger than 64k for split
+# @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+# @pytest.mark.parametrize("world_size", [2])
+def test_cross_entropy_loss_parallel(
+    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
+):
     assert vocab_size % world_size == 0
     rtol, atol = (
         (1e-5, 1e-6)
@@ -56,11 +61,16 @@ def test_cross_entropy_loss_parallel(vocab_size, world_size, smoothing, inplace_
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
         reduction="none",
+        lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
     out = model(x, y)
     out_pt = model_pt(x_pt.float(), y)
+    if lse_square_scale > 0.0:
+        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        out_pt += lse_square_scale * lse_pt.square()
+        out_pt.masked_fill_(y == -100, 0.0)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
 
     g = torch.randn_like(out)