cross_entropy.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Tuple, Optional, Union
  3. import torch
  4. import triton
  5. import triton.language as tl
  6. # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
  7. # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
  8. # version of PyTorch. The following 2 lines are for backward compatibility with
  9. # older PyTorch.
  10. if "all_gather_into_tensor" not in dir(torch.distributed):
  11. torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
    PRECOMPUTED_LSE: tl.constexpr,  # If LSE is already computed (also no smoothing and logit_scale == 1.0)
):
    """Per-row forward pass of (optionally label-smoothed) cross-entropy.

    Each program handles one row (`tl.program_id(0)`). The Python launcher in
    this file uses a grid of `(n_rows,)`.

    Reads:
        logits_ptr: row `row_idx` starts at `row_idx * logits_row_stride`; columns
            are read with unit stride.
        labels_ptr: indexed as `labels_ptr + row_idx`, so labels are assumed to be
            contiguous. Labels are global class ids, shifted by `class_start_idx`
            to get a local column.
    Writes:
        loss_ptr[row]: the row loss. With SPLIT, the LSE term and the z-loss are
            left out because the local LSE is not the final one.
        lse_ptr[row]: log-sum-exp of `logits * logit_scale` over this row's `n_cols`
            columns. Written only when not PRECOMPUTED_LSE; otherwise it is read.
        z_loss_ptr[row]: `lse_square_scale * lse * lse` (0 for ignored rows).
            Written only when not SPLIT.

    Rows whose label equals `ignore_index` get loss 0. Labels outside
    `[class_start_idx, class_start_idx + n_cols)` contribute no `-logit[label]`
    term here, but the smoothing term still applies.
    """
    row_idx = tl.program_id(0)
    # int64 cast so row_idx * stride cannot overflow 32 bits for large inputs.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    sum_logits = 0.0  # For smoothing
    if not PRECOMPUTED_LSE:
        # Statistics for online softmax
        # m_i is the running max; l_i is the running sum of exp(x - m_i). Together
        # they give the LSE in a single pass over the row, BLOCK_SIZE columns at a time.
        m_i = -float("inf")
        l_i = 0.0
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            cols = col_offset + tl.arange(0, BLOCK_SIZE)
            # Out-of-range columns load as -inf, so exp() makes them contribute 0.
            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
                tl.float32
            ) * logit_scale
            if HAS_SMOOTHING:
                # Mask again here: the -inf padding must not enter the plain sum.
                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
            m_i_new = tl.maximum(m_i, tl.max(logits))
            # Rescale the old running sum to the new max before adding this block.
            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
            m_i = m_i_new
        lse = tl.log(l_i) + m_i
        tl.store(lse_ptr + row_idx, lse)
    else:
        lse = tl.load(lse_ptr + row_idx)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        # Convert the global class id to a column of this (possibly sharded) row.
        label_idx -= class_start_idx
        if label_idx >= 0 and label_idx < n_cols:
            # NOTE(review): unlike the loop loads, this value is not explicitly cast to
            # float32. Presumably Triton's promotion with `logit_scale` handles it; confirm.
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                # s * (lse - mean logit) + (1 - s) * (lse - logit[label]), regrouped.
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            # z-loss needs the final LSE, so it is only computed when not SPLIT.
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + row_idx, z_loss)
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    """Gradient of the per-row loss with respect to this rank's logits.

    The grid is 2-D. Program `(r, b)` (`tl.program_id(0)`, `tl.program_id(1)`)
    writes columns `[b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)` of row `r` of
    dlogits, masked to `n_cols`.

    With `p_j = exp(logit_scale * logits[r, j] - lse[r])` and
    `q_j = p_j * (1 + 2 * lse_square_scale * lse[r])` (the second factor is the
    z-loss contribution), the value written for column j is
    `dloss[r] * logit_scale * g_j`, where:
        HAS_SMOOTHING: g_j = q_j - (1 - smoothing) * [j == label] - smoothing / total_classes
        otherwise:     g_j = q_j - [j == label]
    and `label` is the label shifted by `class_start_idx`.

    `dloss` is treated as 0 for rows whose label equals `ignore_index`, so those
    rows get an all-zero gradient. `lse_ptr` is presumably the final (global)
    LSE saved by the forward pass.
    """
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # int64 cast so row_idx * stride cannot overflow 32 bits for large inputs.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        # dloss may be strided (e.g. an expanded gradient), hence the explicit stride.
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    # d(lse_square_scale * lse^2)/dx_j = 2 * lse_square_scale * lse * softmax_j
    probs += 2.0 * lse_square_scale * lse * probs
    # Global class id -> column within this shard (no match if the label lives elsewhere).
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    # Chain rule through `logits * logit_scale`.
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
  138. class CrossEntropyLoss(torch.autograd.Function):
  139. @staticmethod
  140. def forward(
  141. ctx,
  142. logits,
  143. labels,
  144. precomputed_lse=None,
  145. smoothing=0.0,
  146. logit_scale=1.0,
  147. lse_square_scale=0.0,
  148. ignore_index=-100,
  149. inplace_backward=False,
  150. process_group=None,
  151. ):
  152. n_rows, n_cols = logits.shape
  153. assert labels.shape == (n_rows,)
  154. world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
  155. total_classes = world_size * n_cols
  156. rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
  157. class_start_idx = rank * n_cols
  158. use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
  159. if logits.stride(-1) != 1:
  160. logits = logits.contiguous()
  161. MAX_BLOCK_SIZE = 16 * 1024
  162. BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
  163. num_warps = (
  164. 4
  165. if BLOCK_SIZE < 2048
  166. else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
  167. )
  168. losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
  169. if use_precomputed_lse:
  170. assert precomputed_lse.shape == (n_rows,)
  171. lse = precomputed_lse.contiguous()
  172. else:
  173. lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
  174. z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
  175. # Need this, otherwise Triton tries to launch from cuda:0 and we get
  176. # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
  177. with torch.cuda.device(logits.device.index):
  178. cross_entropy_fwd_kernel[(n_rows,)](
  179. losses, # data ptrs
  180. lse,
  181. z_losses,
  182. logits,
  183. labels,
  184. smoothing,
  185. logit_scale,
  186. lse_square_scale,
  187. ignore_index,
  188. total_classes,
  189. class_start_idx,
  190. n_cols, # shapes
  191. logits.stride(0), # strides
  192. BLOCK_SIZE=BLOCK_SIZE, # constants
  193. SPLIT=world_size > 1,
  194. PRECOMPUTED_LSE=use_precomputed_lse,
  195. num_warps=num_warps,
  196. )
  197. if world_size > 1:
  198. # If there's no smoothing, if labels are in the vocab of this partition, losses contains
  199. # - predicted logit, and 0 otherwise.
  200. # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
  201. # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
  202. # For labels not in the vocab of this partition, losses contains
  203. # -0.1 * sum logit / total_classes.
  204. if world_size > 1:
  205. lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
  206. torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
  207. handle_losses = torch.distributed.all_reduce(
  208. losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
  209. )
  210. lse = torch.logsumexp(lse_allgather, dim=0)
  211. handle_losses.wait()
  212. # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
  213. # we just have to add the (global) lse.
  214. # If there's smoothing=0.1, the total losses are
  215. # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
  216. # Again, we just have to add the (global) lse.
  217. losses += lse
  218. if lse_square_scale != 0.0:
  219. z_losses = lse_square_scale * lse.square()
  220. z_losses.masked_fill_(labels == ignore_index, 0.0)
  221. losses += z_losses
  222. else:
  223. z_losses = torch.zeros_like(losses)
  224. losses.masked_fill_(labels == ignore_index, 0.0)
  225. ctx.save_for_backward(logits, lse, labels)
  226. ctx.mark_non_differentiable(z_losses)
  227. ctx.smoothing = smoothing
  228. ctx.logit_scale = logit_scale
  229. ctx.lse_square_scale = lse_square_scale
  230. ctx.ignore_index = ignore_index
  231. ctx.total_classes = total_classes
  232. ctx.class_start_idx = class_start_idx
  233. ctx.inplace_backward = inplace_backward
  234. return losses, z_losses
  235. @staticmethod
  236. def backward(ctx, grad_losses, grad_z_losses):
  237. del grad_z_losses # z_losses are only for logging.
  238. logits, lse, labels = ctx.saved_tensors
  239. dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
  240. n_rows, n_cols = logits.shape
  241. BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
  242. num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
  243. grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
  244. # Need this, otherwise Triton tries to launch from cuda:0 and we get
  245. # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
  246. with torch.cuda.device(logits.device.index):
  247. cross_entropy_bwd_kernel[grid](
  248. dlogits, # data ptrs
  249. grad_losses,
  250. logits,
  251. lse,
  252. labels,
  253. ctx.smoothing,
  254. ctx.logit_scale,
  255. ctx.lse_square_scale,
  256. ctx.ignore_index,
  257. ctx.total_classes,
  258. ctx.class_start_idx,
  259. n_cols, # shapes
  260. logits.stride(0), # strides
  261. dlogits.stride(0),
  262. grad_losses.stride(0),
  263. BLOCK_SIZE=BLOCK_SIZE, # constants
  264. num_warps=num_warps,
  265. )
  266. return dlogits, None, None, None, None, None, None, None, None, None
  267. def cross_entropy_loss(
  268. logits: torch.Tensor,
  269. labels: torch.Tensor,
  270. precomputed_lse: Optional[torch.Tensor] = None,
  271. label_smoothing: float = 0.0,
  272. logit_scale: float = 1.0,
  273. lse_square_scale: float = 0.0,
  274. ignore_index=-100,
  275. inplace_backward: bool = False,
  276. process_group=None,
  277. ) -> Tuple[torch.Tensor, torch.Tensor]:
  278. """
  279. Arguments:
  280. logits: (batch, vocab_size)
  281. labels: (batch,)
  282. label_smoothing: float
  283. logit_scale: float. Multiply logits by this scale before calculating the loss.
  284. lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
  285. This is also referred to as "z-loss".
  286. ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
  287. inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
  288. This saves memory.
  289. process_group: if not None, we're doing Tensor Parallel: each process is responsible for
  290. one part of the vocab. The loss will be aggregated across processes.
  291. Returns:
  292. losses: (batch,), float
  293. z_losses: (batch,), float
  294. """
  295. return CrossEntropyLoss.apply(
  296. logits,
  297. labels,
  298. precomputed_lse,
  299. label_smoothing,
  300. logit_scale,
  301. lse_square_scale,
  302. ignore_index,
  303. inplace_backward,
  304. process_group,
  305. )