cross_entropy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # Copyright (c) 2023, Tri Dao.
  2. from typing import Tuple, Optional, Union
  3. import torch
  4. import torch.nn.functional as F
  5. import triton
  6. import triton.language as tl
  7. # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
  8. # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
  9. # version of PyTorch. The following 2 lines are for backward compatibility with
  10. # older PyTorch.
  11. if "all_gather_into_tensor" not in dir(torch.distributed):
  12. torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
  13. @triton.heuristics(
  14. {
  15. "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
  16. }
  17. )
  18. @triton.jit
  19. def cross_entropy_fwd_kernel(
  20. loss_ptr, # data ptrs
  21. lse_ptr,
  22. z_loss_ptr,
  23. logits_ptr,
  24. labels_ptr,
  25. smoothing,
  26. logit_scale,
  27. lse_square_scale,
  28. ignore_index,
  29. total_classes,
  30. class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
  31. n_cols, # shapes
  32. logits_row_stride, # strides
  33. BLOCK_SIZE: tl.constexpr,
  34. HAS_SMOOTHING: tl.constexpr,
  35. # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
  36. SPLIT: tl.constexpr,
  37. PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0)
  38. ):
  39. row_idx = tl.program_id(0)
  40. logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
  41. sum_logits = 0.0 # For smoothing
  42. if not PRECOMPUTED_LSE:
  43. # Statistics for online softmax
  44. m_i = -float("inf")
  45. l_i = 0.0
  46. for col_offset in range(0, n_cols, BLOCK_SIZE):
  47. cols = col_offset + tl.arange(0, BLOCK_SIZE)
  48. logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
  49. tl.float32
  50. ) * logit_scale
  51. if HAS_SMOOTHING:
  52. sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
  53. m_i_new = tl.maximum(m_i, tl.max(logits))
  54. l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
  55. m_i = m_i_new
  56. lse = tl.log(l_i) + m_i
  57. tl.store(lse_ptr + row_idx, lse)
  58. else:
  59. lse = tl.load(lse_ptr + row_idx)
  60. label_idx = tl.load(labels_ptr + row_idx)
  61. if label_idx == ignore_index:
  62. loss = 0.0
  63. z_loss = 0.0
  64. else:
  65. label_idx -= class_start_idx
  66. if label_idx >= 0 and label_idx < n_cols:
  67. logits_label = tl.load(logits_ptr + label_idx) * logit_scale
  68. if HAS_SMOOTHING:
  69. loss = (
  70. (lse if not SPLIT else 0.0)
  71. - smoothing * sum_logits / total_classes
  72. - (1 - smoothing) * logits_label
  73. )
  74. else:
  75. loss = (lse if not SPLIT else 0.0) - logits_label
  76. else:
  77. # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
  78. if HAS_SMOOTHING:
  79. loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
  80. else:
  81. loss = 0.0
  82. if not SPLIT:
  83. z_loss = lse_square_scale * lse * lse
  84. loss += z_loss
  85. else:
  86. z_loss = 0.0
  87. tl.store(loss_ptr + row_idx, loss)
  88. if not SPLIT:
  89. tl.store(z_loss_ptr + row_idx, z_loss)
  90. @triton.heuristics(
  91. {
  92. "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
  93. }
  94. )
  95. @triton.jit
  96. def cross_entropy_bwd_kernel(
  97. dlogits_ptr, # data ptrs
  98. dloss_ptr,
  99. logits_ptr,
  100. lse_ptr,
  101. labels_ptr,
  102. smoothing,
  103. logit_scale,
  104. lse_square_scale,
  105. ignore_index,
  106. total_classes,
  107. class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
  108. n_cols, # shapes
  109. logits_row_stride, # strides
  110. dlogits_row_stride,
  111. dloss_row_stride,
  112. BLOCK_SIZE: tl.constexpr,
  113. HAS_SMOOTHING: tl.constexpr,
  114. ):
  115. row_idx = tl.program_id(0)
  116. col_block_idx = tl.program_id(1)
  117. logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
  118. dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
  119. col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  120. label_idx = tl.load(labels_ptr + row_idx)
  121. if label_idx != ignore_index:
  122. dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
  123. else:
  124. dloss = 0.0
  125. logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
  126. tl.float32
  127. ) * logit_scale
  128. lse = tl.load(lse_ptr + row_idx)
  129. probs = tl.exp(logits - lse)
  130. probs += 2.0 * lse_square_scale * lse * probs
  131. label_idx -= class_start_idx
  132. if HAS_SMOOTHING:
  133. smooth_positive = 1.0 - smoothing
  134. smooth_negative = smoothing / total_classes
  135. probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
  136. else:
  137. probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
  138. tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
  139. class CrossEntropyLoss(torch.autograd.Function):
  140. @staticmethod
  141. def forward(
  142. ctx,
  143. logits,
  144. labels,
  145. precomputed_lse=None,
  146. smoothing=0.0,
  147. logit_scale=1.0,
  148. lse_square_scale=0.0,
  149. ignore_index=-100,
  150. inplace_backward=False,
  151. process_group=None,
  152. ):
  153. # For some reason Triton generates wrong code when labels has dtype long and its address
  154. # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
  155. if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
  156. labels = F.pad(labels, (0, 1))[..., :-1]
  157. assert labels.data_ptr() % 16 == 0
  158. n_rows, n_cols = logits.shape
  159. assert labels.shape == (n_rows,)
  160. world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
  161. total_classes = world_size * n_cols
  162. rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
  163. class_start_idx = rank * n_cols
  164. use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
  165. if logits.stride(-1) != 1:
  166. logits = logits.contiguous()
  167. MAX_BLOCK_SIZE = 16 * 1024
  168. BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
  169. num_warps = (
  170. 4
  171. if BLOCK_SIZE < 2048
  172. else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
  173. )
  174. losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
  175. if use_precomputed_lse:
  176. assert precomputed_lse.shape == (n_rows,)
  177. lse = precomputed_lse.contiguous()
  178. else:
  179. lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
  180. z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
  181. # Need this, otherwise Triton tries to launch from cuda:0 and we get
  182. # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
  183. with torch.cuda.device(logits.device.index):
  184. cross_entropy_fwd_kernel[(n_rows,)](
  185. losses, # data ptrs
  186. lse,
  187. z_losses,
  188. logits,
  189. labels,
  190. smoothing,
  191. logit_scale,
  192. lse_square_scale,
  193. ignore_index,
  194. total_classes,
  195. class_start_idx,
  196. n_cols, # shapes
  197. logits.stride(0), # strides
  198. BLOCK_SIZE=BLOCK_SIZE, # constants
  199. SPLIT=world_size > 1,
  200. PRECOMPUTED_LSE=use_precomputed_lse,
  201. num_warps=num_warps,
  202. )
  203. if world_size > 1:
  204. # If there's no smoothing, if labels are in the vocab of this partition, losses contains
  205. # - predicted logit, and 0 otherwise.
  206. # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
  207. # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
  208. # For labels not in the vocab of this partition, losses contains
  209. # -0.1 * sum logit / total_classes.
  210. if world_size > 1:
  211. lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
  212. torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
  213. handle_losses = torch.distributed.all_reduce(
  214. losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
  215. )
  216. lse = torch.logsumexp(lse_allgather, dim=0)
  217. handle_losses.wait()
  218. # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
  219. # we just have to add the (global) lse.
  220. # If there's smoothing=0.1, the total losses are
  221. # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
  222. # Again, we just have to add the (global) lse.
  223. losses += lse
  224. if lse_square_scale != 0.0:
  225. z_losses = lse_square_scale * lse.square()
  226. z_losses.masked_fill_(labels == ignore_index, 0.0)
  227. losses += z_losses
  228. else:
  229. z_losses = torch.zeros_like(losses)
  230. losses.masked_fill_(labels == ignore_index, 0.0)
  231. ctx.save_for_backward(logits, lse, labels)
  232. ctx.mark_non_differentiable(z_losses)
  233. ctx.smoothing = smoothing
  234. ctx.logit_scale = logit_scale
  235. ctx.lse_square_scale = lse_square_scale
  236. ctx.ignore_index = ignore_index
  237. ctx.total_classes = total_classes
  238. ctx.class_start_idx = class_start_idx
  239. ctx.inplace_backward = inplace_backward
  240. return losses, z_losses
  241. @staticmethod
  242. def backward(ctx, grad_losses, grad_z_losses):
  243. del grad_z_losses # z_losses are only for logging.
  244. logits, lse, labels = ctx.saved_tensors
  245. dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
  246. n_rows, n_cols = logits.shape
  247. BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
  248. num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
  249. grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
  250. # Need this, otherwise Triton tries to launch from cuda:0 and we get
  251. # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
  252. with torch.cuda.device(logits.device.index):
  253. cross_entropy_bwd_kernel[grid](
  254. dlogits, # data ptrs
  255. grad_losses,
  256. logits,
  257. lse,
  258. labels,
  259. ctx.smoothing,
  260. ctx.logit_scale,
  261. ctx.lse_square_scale,
  262. ctx.ignore_index,
  263. ctx.total_classes,
  264. ctx.class_start_idx,
  265. n_cols, # shapes
  266. logits.stride(0), # strides
  267. dlogits.stride(0),
  268. grad_losses.stride(0),
  269. BLOCK_SIZE=BLOCK_SIZE, # constants
  270. num_warps=num_warps,
  271. )
  272. return dlogits, None, None, None, None, None, None, None, None, None
  273. def cross_entropy_loss(
  274. logits: torch.Tensor,
  275. labels: torch.Tensor,
  276. precomputed_lse: Optional[torch.Tensor] = None,
  277. label_smoothing: float = 0.0,
  278. logit_scale: float = 1.0,
  279. lse_square_scale: float = 0.0,
  280. ignore_index=-100,
  281. inplace_backward: bool = False,
  282. process_group=None,
  283. ) -> Tuple[torch.Tensor, torch.Tensor]:
  284. """
  285. Arguments:
  286. logits: (batch, vocab_size)
  287. labels: (batch,)
  288. label_smoothing: float
  289. logit_scale: float. Multiply logits by this scale before calculating the loss.
  290. lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
  291. This is also referred to as "z-loss".
  292. ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
  293. inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
  294. This saves memory.
  295. process_group: if not None, we're doing Tensor Parallel: each process is responsible for
  296. one part of the vocab. The loss will be aggregated across processes.
  297. Returns:
  298. losses: (batch,), float
  299. z_losses: (batch,), float
  300. """
  301. return CrossEntropyLoss.apply(
  302. logits,
  303. labels,
  304. precomputed_lse,
  305. label_smoothing,
  306. logit_scale,
  307. lse_square_scale,
  308. ignore_index,
  309. inplace_backward,
  310. process_group,
  311. )