1
0
Эх сурвалжийг харах

[CrossEntropy] Use online softmax to simplify implementation

Tri Dao 7 сар өмнө
parent
commit
d79f9b41a8

+ 27 - 34
flash_attn/ops/triton/cross_entropy.py

@@ -34,7 +34,6 @@ def cross_entropy_fwd_kernel(
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
-    n_rows,
     logits_row_stride,  # strides
     BLOCK_SIZE: tl.constexpr,
     HAS_SMOOTHING: tl.constexpr,
@@ -42,26 +41,30 @@ def cross_entropy_fwd_kernel(
     SPLIT: tl.constexpr,
 ):
     row_idx = tl.program_id(0)
-    col_block_idx = tl.program_id(1)
     logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
-    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    sum_logits = 0.0  # For smoothing
+    # Statistics for online softmax
+    m_i = -float("inf")
+    l_i = 0.0
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        cols = col_offset + tl.arange(0, BLOCK_SIZE)
+        logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
+            tl.float32
+        ) * logit_scale
+        if HAS_SMOOTHING:
+            sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
+        m_i_new = tl.maximum(m_i, tl.max(logits))
+        l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
+        m_i = m_i_new
+    lse = tl.log(l_i) + m_i
+    tl.store(lse_ptr + row_idx, lse)
     label_idx = tl.load(labels_ptr + row_idx)
-    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
-        tl.float32
-    ) * logit_scale
-    max_logits = tl.max(logits, 0)
-    if HAS_SMOOTHING:
-        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
-    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
-    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
     if label_idx == ignore_index:
         loss = 0.0
         z_loss = 0.0
     else:
         label_idx -= class_start_idx
-        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
-            n_cols, (col_block_idx + 1) * BLOCK_SIZE
-        ):
+        if label_idx >= 0 and label_idx < n_cols:
             logits_label = tl.load(logits_ptr + label_idx) * logit_scale
             if HAS_SMOOTHING:
                 loss = (
@@ -82,9 +85,9 @@ def cross_entropy_fwd_kernel(
             loss += z_loss
         else:
             z_loss = 0.0
-    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
+    tl.store(loss_ptr + row_idx, loss)
     if not SPLIT:
-        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
+        tl.store(z_loss_ptr + row_idx, z_loss)
 
 
 @triton.heuristics(
@@ -161,27 +164,20 @@ class CrossEntropyLoss(torch.autograd.Function):
 
         if logits.stride(-1) != 1:
             logits = logits.contiguous()
-        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
-        MAX_BLOCK_SIZE = 64 * 1024
+        MAX_BLOCK_SIZE = 16 * 1024
         BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
         num_warps = (
             4
             if BLOCK_SIZE < 2048
             else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
         )
-        # We may split the lse computation across multiple blocks, then do a reduction
-        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
-        # where having just one thread block processing more than 64k elements is slow.
-        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
-        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
-        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
-        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
-        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
-        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
+        losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
         # Need this, otherwise Triton tries to launch from cuda:0 and we get
         # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
         with torch.cuda.device(logits.device.index):
-            cross_entropy_fwd_kernel[(n_rows, n_splits)](
+            cross_entropy_fwd_kernel[(n_rows,)](
                 losses,  # data ptrs
                 lse,
                 z_losses,
@@ -194,23 +190,19 @@ class CrossEntropyLoss(torch.autograd.Function):
                 total_classes,
                 class_start_idx,
                 n_cols,  # shapes
-                n_rows,
                 logits.stride(0),  # strides
                 BLOCK_SIZE=BLOCK_SIZE,  # constants
                 num_warps=num_warps,
-                SPLIT=split,
+                SPLIT=world_size > 1,
             )
 
-        if split:
+        if world_size > 1:
             # If there's no smoothing, if labels are in the vocab of this partition, losses contains
             # - predicted logit, and 0 otherwise.
             # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
             # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
             # For labels not in the vocab of this partition, losses contains
             # -0.1 * sum logit / total_classes.
-            if n_splits > 1:
-                lse = torch.logsumexp(lse, dim=0)
-                losses = losses.sum(dim=0)
             if world_size > 1:
                 lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                 torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
@@ -243,6 +235,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward
 
+
         return losses, z_losses
 
     @staticmethod