cross_entropy.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Tuple, Optional, Union
  3. import torch
  4. import triton
  5. import triton.language as tl
  6. # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
  7. # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
  8. # version of PyTorch. The following 2 lines are for backward compatibility with
  9. # older PyTorch.
  10. if "all_gather_into_tensor" not in dir(torch.distributed):
  11. torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
  12. @triton.heuristics(
  13. {
  14. "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
  15. }
  16. )
  17. @triton.jit
  18. def cross_entropy_fwd_kernel(
  19. loss_ptr, # data ptrs
  20. lse_ptr,
  21. z_loss_ptr,
  22. logits_ptr,
  23. labels_ptr,
  24. smoothing,
  25. logit_scale,
  26. lse_square_scale,
  27. ignore_index,
  28. total_classes,
  29. class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
  30. n_cols, # shapes
  31. n_rows,
  32. logits_row_stride, # strides
  33. BLOCK_SIZE: tl.constexpr,
  34. HAS_SMOOTHING: tl.constexpr,
  35. # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
  36. SPLIT: tl.constexpr,
  37. ):
  38. row_idx = tl.program_id(0)
  39. col_block_idx = tl.program_id(1)
  40. logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
  41. col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  42. label_idx = tl.load(labels_ptr + row_idx)
  43. logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
  44. tl.float32
  45. ) * logit_scale
  46. max_logits = tl.max(logits, 0)
  47. if HAS_SMOOTHING:
  48. sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
  49. lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
  50. tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
  51. if label_idx == ignore_index:
  52. loss = 0.0
  53. z_loss = 0.0
  54. else:
  55. label_idx -= class_start_idx
  56. if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
  57. n_cols, (col_block_idx + 1) * BLOCK_SIZE
  58. ):
  59. logits_label = tl.load(logits_ptr + label_idx) * logit_scale
  60. if HAS_SMOOTHING:
  61. loss = (
  62. (lse if not SPLIT else 0.0)
  63. - smoothing * sum_logits / total_classes
  64. - (1 - smoothing) * logits_label
  65. )
  66. else:
  67. loss = (lse if not SPLIT else 0.0) - logits_label
  68. else:
  69. # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
  70. if HAS_SMOOTHING:
  71. loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
  72. else:
  73. loss = 0.0
  74. if not SPLIT:
  75. z_loss = lse_square_scale * lse * lse
  76. loss += z_loss
  77. else:
  78. z_loss = 0.0
  79. tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
  80. if not SPLIT:
  81. tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
  82. @triton.heuristics(
  83. {
  84. "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
  85. }
  86. )
  87. @triton.jit
  88. def cross_entropy_bwd_kernel(
  89. dlogits_ptr, # data ptrs
  90. dloss_ptr,
  91. logits_ptr,
  92. lse_ptr,
  93. labels_ptr,
  94. smoothing,
  95. logit_scale,
  96. lse_square_scale,
  97. ignore_index,
  98. total_classes,
  99. class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
  100. n_cols, # shapes
  101. logits_row_stride, # strides
  102. dlogits_row_stride,
  103. dloss_row_stride,
  104. BLOCK_SIZE: tl.constexpr,
  105. HAS_SMOOTHING: tl.constexpr,
  106. ):
  107. row_idx = tl.program_id(0)
  108. col_block_idx = tl.program_id(1)
  109. logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
  110. dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
  111. col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  112. label_idx = tl.load(labels_ptr + row_idx)
  113. if label_idx != ignore_index:
  114. dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
  115. else:
  116. dloss = 0.0
  117. logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
  118. tl.float32
  119. ) * logit_scale
  120. lse = tl.load(lse_ptr + row_idx)
  121. probs = tl.exp(logits - lse)
  122. probs += 2.0 * lse_square_scale * lse * probs
  123. label_idx -= class_start_idx
  124. if HAS_SMOOTHING:
  125. smooth_positive = 1.0 - smoothing
  126. smooth_negative = smoothing / total_classes
  127. probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
  128. else:
  129. probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
  130. tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
  131. class CrossEntropyLoss(torch.autograd.Function):
  132. @staticmethod
  133. def forward(
  134. ctx,
  135. logits,
  136. labels,
  137. smoothing=0.0,
  138. logit_scale=1.0,
  139. lse_square_scale=0.0,
  140. ignore_index=-100,
  141. inplace_backward=False,
  142. process_group=None,
  143. ):
  144. n_rows, n_cols = logits.shape
  145. assert labels.shape == (n_rows,)
  146. world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
  147. total_classes = world_size * n_cols
  148. rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
  149. class_start_idx = rank * n_cols
  150. if logits.stride(-1) != 1:
  151. logits = logits.contiguous()
  152. # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
  153. MAX_BLOCK_SIZE = 64 * 1024
  154. BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
  155. num_warps = (
  156. 4
  157. if BLOCK_SIZE < 2048
  158. else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
  159. )
  160. # We may split the lse computation across multiple blocks, then do a reduction
  161. # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
  162. # where having just one thread block processing more than 64k elements is slow.
  163. split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
  164. n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
  165. loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
  166. losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
  167. lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
  168. z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
  169. # Need this, otherwise Triton tries to launch from cuda:0 and we get
  170. # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
  171. with torch.cuda.device(logits.device.index):
  172. cross_entropy_fwd_kernel[(n_rows, n_splits)](
  173. losses, # data ptrs
  174. lse,
  175. z_losses,
  176. logits,
  177. labels,
  178. smoothing,
  179. logit_scale,
  180. lse_square_scale,
  181. ignore_index,
  182. total_classes,
  183. class_start_idx,
  184. n_cols, # shapes
  185. n_rows,
  186. logits.stride(0), # strides
  187. BLOCK_SIZE=BLOCK_SIZE, # constants
  188. num_warps=num_warps,
  189. SPLIT=split,
  190. )
  191. if split:
  192. # If there's no smoothing, if labels are in the vocab of this partition, losses contains
  193. # - predicted logit, and 0 otherwise.
  194. # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
  195. # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
  196. # For labels not in the vocab of this partition, losses contains
  197. # -0.1 * sum logit / total_classes.
  198. if n_splits > 1:
  199. lse = torch.logsumexp(lse, dim=0)
  200. losses = losses.sum(dim=0)
  201. if world_size > 1:
  202. lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
  203. torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
  204. handle_losses = torch.distributed.all_reduce(
  205. losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
  206. )
  207. lse = torch.logsumexp(lse_allgather, dim=0)
  208. handle_losses.wait()
  209. # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
  210. # we just have to add the (global) lse.
  211. # If there's smoothing=0.1, the total losses are
  212. # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
  213. # Again, we just have to add the (global) lse.
  214. losses += lse
  215. if lse_square_scale != 0.0:
  216. z_losses = lse_square_scale * lse.square()
  217. z_losses.masked_fill_(labels == ignore_index, 0.0)
  218. losses += z_losses
  219. else:
  220. z_losses = torch.zeros_like(losses)
  221. losses.masked_fill_(labels == ignore_index, 0.0)
  222. ctx.save_for_backward(logits, lse, labels)
  223. ctx.mark_non_differentiable(z_losses)
  224. ctx.smoothing = smoothing
  225. ctx.logit_scale = logit_scale
  226. ctx.lse_square_scale = lse_square_scale
  227. ctx.ignore_index = ignore_index
  228. ctx.total_classes = total_classes
  229. ctx.class_start_idx = class_start_idx
  230. ctx.inplace_backward = inplace_backward
  231. return losses, z_losses
  232. @staticmethod
  233. def backward(ctx, grad_losses, grad_z_losses):
  234. del grad_z_losses # z_losses are only for logging.
  235. logits, lse, labels = ctx.saved_tensors
  236. dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
  237. n_rows, n_cols = logits.shape
  238. BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
  239. num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
  240. grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
  241. # Need this, otherwise Triton tries to launch from cuda:0 and we get
  242. # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
  243. with torch.cuda.device(logits.device.index):
  244. cross_entropy_bwd_kernel[grid](
  245. dlogits, # data ptrs
  246. grad_losses,
  247. logits,
  248. lse,
  249. labels,
  250. ctx.smoothing,
  251. ctx.logit_scale,
  252. ctx.lse_square_scale,
  253. ctx.ignore_index,
  254. ctx.total_classes,
  255. ctx.class_start_idx,
  256. n_cols, # shapes
  257. logits.stride(0), # strides
  258. dlogits.stride(0),
  259. grad_losses.stride(0),
  260. BLOCK_SIZE=BLOCK_SIZE, # constants
  261. num_warps=num_warps,
  262. )
  263. return dlogits, None, None, None, None, None, None, None, None
  264. def cross_entropy_loss(
  265. logits: torch.Tensor,
  266. labels: torch.Tensor,
  267. label_smoothing: float = 0.0,
  268. logit_scale: float = 1.0,
  269. lse_square_scale: float = 0.0,
  270. ignore_index=-100,
  271. inplace_backward: bool = False,
  272. process_group=None,
  273. ) -> Tuple[torch.Tensor, torch.Tensor]:
  274. """
  275. Arguments:
  276. logits: (batch, vocab_size)
  277. labels: (batch,)
  278. label_smoothing: float
  279. logit_scale: float. Multiply logits by this scale before calculating the loss.
  280. lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
  281. This is also referred to as "z-loss".
  282. ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
  283. inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
  284. This saves memory.
  285. process_group: if not None, we're doing Tensor Parallel: each process is responsible for
  286. one part of the vocab. The loss will be aggregated across processes.
  287. Returns:
  288. losses: (batch,), float
  289. z_losses: (batch,), float
  290. """
  291. return CrossEntropyLoss.apply(
  292. logits,
  293. labels,
  294. label_smoothing,
  295. logit_scale,
  296. lse_square_scale,
  297. ignore_index,
  298. inplace_backward,
  299. process_group,
  300. )