# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional, Union

import torch

from einops import rearrange

import triton
import triton.language as tl

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent version of
# PyTorch. The following 2 lines are for backward compatibility with older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    )
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignored_index:
        loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx)
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            loss += lse_square_scale * lse * lse
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
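

# Illustrative reference only: a plain-PyTorch sketch of the per-row quantity the forward kernel
# above computes in the simple (non-split, single-rank) case. The helper name and its use of
# eager torch ops are assumptions for exposition; it is not called anywhere in this module.
def _cross_entropy_fwd_reference(
    logits_row, label, smoothing=0.0, lse_square_scale=0.0, ignored_index=-100
):
    logits_row = logits_row.float()
    lse = torch.logsumexp(logits_row, dim=-1)
    if label == ignored_index:
        return torch.zeros_like(lse)
    # Plain cross entropy: lse minus the logit of the true class
    loss = lse - logits_row[label]
    if smoothing > 0.0:
        # Label smoothing mixes in the average negative log-likelihood over all classes
        loss = (1 - smoothing) * loss + smoothing * (lse - logits_row.mean())
    # z-loss regularization on the log-sum-exp
    return loss + lse_square_scale * lse * lse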


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    lse_square_scale,
    ignored_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignored_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    )
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
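

# The backward kernel above implements, per row (with dloss = 0 for ignored labels):
#     dlogits = dloss * (p - target),    where p = softmax(logits) = exp(logits - lse)
# target is one_hot(label) without smoothing, and
#     (1 - smoothing) * one_hot(label) + smoothing / total_classes
# with label smoothing; the z-loss contributes the extra 2 * lse_square_scale * lse * p term.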


class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing,
        lse_square_scale=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols
        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then take the logsumexp of the
        # per-block LSEs to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
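        # A minimal sketch of the identity this relies on (illustrative, not executed here):
        #     logsumexp(x) == logsumexp([logsumexp(x[:k]), logsumexp(x[k:])])
        # so each block computes a local LSE, and the partial losses/LSEs are combined below.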
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                logits,
                labels,
                smoothing,
                lse_square_scale,
                ignored_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )
        if split:
            # If there's no smoothing, for labels in the vocab of this partition, losses contains
            # -predicted_logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted_logit - 0.1 * sum_logits / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum_logits / total_classes.
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            else:
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            # After the allreduce, if there's no smoothing, the total losses are -predicted_logit,
            # so we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum_logits / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                losses += lse_square_scale * lse.square()
            losses.masked_fill_(labels == ignored_index, 0.0)
        ctx.save_for_backward(logits, lse, labels)
        ctx.smoothing = smoothing
        ctx.lse_square_scale = lse_square_scale
        ctx.ignored_index = ignored_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_losses):
        logits, lse, labels = ctx.saved_tensors
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.lse_square_scale,
                ctx.ignored_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        return dlogits, None, None, None, None, None, None, None


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    lse_square_scale: float = 0.0,
    ignored_index: int = -100,
    inplace_backward: bool = False,
    process_group=None,
) -> torch.Tensor:
    """
    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
    """
    return CrossEntropyLoss.apply(
        logits,
        labels,
        label_smoothing,
        lse_square_scale,
        ignored_index,
        inplace_backward,
        process_group,
    )
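

# Example usage (illustrative only; the shapes, vocab size, dtype, and CUDA device below are
# assumptions, not part of the original module). The Triton kernels require a CUDA GPU.
if __name__ == "__main__":
    logits = torch.randn(8, 32768, device="cuda", dtype=torch.float16, requires_grad=True)
    labels = torch.randint(0, 32768, (8,), device="cuda")
    # Per-example losses with label smoothing and a small z-loss term
    losses = cross_entropy_loss(logits, labels, label_smoothing=0.1, lse_square_scale=1e-4)
    losses.mean().backward()  # gradients flow back into `logits` via the Triton backward kernel
    print(losses)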