# Copyright (c) 2023, Tri Dao.
- from typing import Tuple, Optional, Union
- import torch
- import triton
- import triton.language as tl
# `all_gather_into_tensor` is the newer public name for `_all_gather_base` (likewise
# `reduce_scatter_tensor` for `_reduce_scatter_base`); it requires a recent version of
# PyTorch. This file only calls `all_gather_into_tensor`, so that is the only name
# aliased here for backward compatibility with older PyTorch.
# The alias is installed only when the legacy function actually exists, so that a
# `torch.distributed` module exposing neither name (e.g. a build without distributed
# support) does not turn this shim into an import-time AttributeError.
if not hasattr(torch.distributed, "all_gather_into_tensor") and hasattr(
    torch.distributed, "_all_gather_base"
):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    """Forward pass for one (row, column block) tile of the cross-entropy loss.

    Program id 0 selects the row and program id 1 selects a block of BLOCK_SIZE
    columns. Each program writes its results at index `col_block * n_rows + row`:
      * lse_ptr: log-sum-exp of the scaled logits of this column block.
      * loss_ptr: this block's contribution to the loss. It is 0 when the label equals
        ignore_index. When SPLIT is set the LSE and z-loss terms are left out, because
        the block-local LSE is not the final one.
      * z_loss_ptr: lse_square_scale * lse^2 (0 for ignored rows); written only when
        SPLIT is not set.
    """
    row = tl.program_id(0)
    block = tl.program_id(1)
    block_start = block * BLOCK_SIZE
    out_slot = block * n_rows + row
    # Row offset is computed in int64 so large logits tensors do not overflow.
    row_ptr = logits_ptr + row * logits_row_stride.to(tl.int64)
    cols = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols
    label_idx = tl.load(labels_ptr + row)
    # Columns past n_cols are loaded as -inf so they affect neither the max nor the
    # sum of exponentials below.
    raw_logits = tl.load(row_ptr + cols, mask=in_bounds, other=-float("inf"))
    logits = raw_logits.to(tl.float32) * logit_scale
    row_max = tl.max(logits, 0)
    if HAS_SMOOTHING:
        # Padding columns are replaced by 0 here; summing the -inf fill would poison the sum.
        sum_logits = tl.sum(tl.where(in_bounds, logits, 0.0), 0)
    # Max-subtracted log-sum-exp over this block.
    lse = tl.log(tl.sum(tl.exp(logits - row_max), 0)) + row_max
    tl.store(lse_ptr + out_slot, lse)
    # LSE term that enters the loss: omitted when SPLIT, since the caller adds the final LSE.
    lse_term = lse if not SPLIT else 0.0
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        # Translate the global class id into this rank's local column index.
        local_label = label_idx - class_start_idx
        if local_label >= block_start and local_label < min(n_cols, block_start + BLOCK_SIZE):
            logits_label = tl.load(row_ptr + local_label) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    lse_term
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = lse_term - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * (lse_term - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + out_slot, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + out_slot, z_loss)
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    """Backward pass for one (row, column block) tile of the cross-entropy loss.

    Program id 0 selects the row and program id 1 selects a block of BLOCK_SIZE
    columns. For every in-bounds column the kernel stores

        dloss * logit_scale * (p * (1 + 2 * lse_square_scale * lse) - target)

    into dlogits, where p = exp(scaled_logit - lse) and lse is read per row from
    lse_ptr. Without smoothing, target is 1 at the label column and 0 elsewhere; with
    smoothing it is (1 - smoothing) at the label column plus smoothing / total_classes
    on every column. Rows whose label equals ignore_index use dloss = 0, so their
    gradient is written as zeros.
    """
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # Row offsets are computed in int64 so large tensors do not overflow.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    # Padding columns load as -inf, which gives them probability 0; they are masked on store.
    logits = tl.load(logits_ptr + col_offsets, mask=col_mask, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    # Gradient of the z-loss term lse_square_scale * lse^2 with respect to the logits.
    probs += 2.0 * lse_square_scale * lse * probs
    # Translate the global class id into this rank's local column index. A label owned
    # by another rank matches no column here, so only the softmax term remains.
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_mask)
class CrossEntropyLoss(torch.autograd.Function):
    """Autograd function for per-row cross-entropy backed by the Triton kernels in this file.

    Supports label smoothing, logit scaling, an auxiliary z-loss
    (lse_square_scale * lse^2), ignored labels, and a vocabulary that is sharded across
    the ranks of a process group.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """Compute the per-row losses and z-losses.

        Arguments:
            logits: (n_rows, n_cols). With a process group, n_cols is this rank's share
                of the classes and total_classes = world_size * n_cols.
            labels: (n_rows,). Rows whose label equals ignore_index get loss 0.
            smoothing: label smoothing factor.
            logit_scale: logits are multiplied by this before the loss is computed.
            lse_square_scale: weight of the z-loss term added to the loss.
            ignore_index: label value marking rows to skip.
            inplace_backward: if True, backward writes the gradient into the saved
                logits buffer instead of allocating a new tensor.
            process_group: if not None, the LSE and losses are combined across its ranks.
        Returns:
            losses: (n_rows,), float32.
            z_losses: (n_rows,), float32, marked non-differentiable.
        """
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        # Both kernels read labels as `labels_ptr + row_idx`, i.e. with unit stride, so a
        # strided view would silently pick up the wrong targets. This is a no-op for
        # labels that are already contiguous.
        labels = labels.contiguous()
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        # NOTE: BLOCK_SIZE never exceeds MAX_BLOCK_SIZE, so the 32-warp case is not reached.
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then do a reduction
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                z_losses,
                logits,
                labels,
                smoothing,
                logit_scale,
                lse_square_scale,
                ignore_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            if n_splits > 1:
                # Reduce the per-block partial results of this rank.
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                # The all-reduce is launched asynchronously and waited on after the
                # logsumexp over the gathered LSEs.
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            # The kernel skipped the LSE for every row in split mode, so ignored rows
            # are zeroed only after the LSE has been added back.
            losses.masked_fill_(labels == ignore_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        """Compute the gradient with respect to the logits.

        Returns dlogits followed by one None for each of the seven remaining forward
        arguments, none of which receives a gradient. When inplace_backward was set,
        dlogits is the saved logits buffer, overwritten in place.
        """
        del grad_z_losses  # z_losses are only for logging.

        logits, lse, labels = ctx.saved_tensors
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.logit_scale,
                ctx.lse_square_scale,
                ctx.ignore_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        # One entry per forward input: logits, labels, smoothing, logit_scale,
        # lse_square_scale, ignore_index, inplace_backward, process_group.
        return dlogits, None, None, None, None, None, None, None
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the per-row cross-entropy loss and z-loss via ``CrossEntropyLoss``.

    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float
        logit_scale: float. The logits are multiplied by this scale before the loss
            is calculated.
        lse_square_scale: float. If > 0, lse_square_scale * lse(logits) ^ 2 is added to
            the loss. This term is also known as the "z-loss".
        ignore_index: int. Rows where labels == ignore_index get a loss of 0.0.
        inplace_backward: bool. If True, the backward pass writes the gradient into
            the logits in-place, which saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process holds one
            part of the vocab, and the loss is aggregated across processes.
    Returns:
        losses: (batch,), float
        z_losses: (batch,), float
    """
    # The arguments are forwarded positionally, so the order below must match the
    # parameter order of CrossEntropyLoss.forward (after ctx).
    forward_args = (
        logits,
        labels,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
    return CrossEntropyLoss.apply(*forward_args)
|