bert_padding.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
  2. import torch
  3. import torch.nn.functional as F
  4. from einops import rearrange, repeat
  5. class IndexFirstAxis(torch.autograd.Function):
  6. @staticmethod
  7. def forward(ctx, input, indices):
  8. ctx.save_for_backward(indices)
  9. assert input.ndim >= 2
  10. ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
  11. second_dim = other_shape.numel()
  12. # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
  13. # return input[indices]
  14. return torch.gather(
  15. rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
  16. ).reshape(-1, *other_shape)
  17. @staticmethod
  18. def backward(ctx, grad_output):
  19. (indices,) = ctx.saved_tensors
  20. assert grad_output.ndim >= 2
  21. other_shape = grad_output.shape[1:]
  22. grad_output = rearrange(grad_output, "b ... -> b (...)")
  23. grad_input = torch.zeros(
  24. [ctx.first_axis_dim, grad_output.shape[1]],
  25. device=grad_output.device,
  26. dtype=grad_output.dtype,
  27. )
  28. # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
  29. # grad_input[indices] = grad_output
  30. grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
  31. return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
  32. index_first_axis = IndexFirstAxis.apply
  33. class IndexPutFirstAxis(torch.autograd.Function):
  34. @staticmethod
  35. def forward(ctx, values, indices, first_axis_dim):
  36. ctx.save_for_backward(indices)
  37. assert indices.ndim == 1
  38. assert values.ndim >= 2
  39. output = torch.zeros(
  40. first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
  41. )
  42. # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
  43. output[indices] = values
  44. # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
  45. return output
  46. @staticmethod
  47. def backward(ctx, grad_output):
  48. (indices,) = ctx.saved_tensors
  49. # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
  50. grad_values = grad_output[indices]
  51. # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
  52. return grad_values, None, None
  53. index_put_first_axis = IndexPutFirstAxis.apply
  54. class IndexFirstAxisResidual(torch.autograd.Function):
  55. @staticmethod
  56. def forward(ctx, input, indices):
  57. ctx.save_for_backward(indices)
  58. assert input.ndim >= 2
  59. ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
  60. second_dim = other_shape.numel()
  61. # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
  62. output = input[indices]
  63. # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
  64. # memory format to channel_first. In other words, input might not be contiguous.
  65. # If we don't detach, Pytorch complains about output being a view and is being modified inplace
  66. return output, input.detach()
  67. @staticmethod
  68. def backward(ctx, grad_output, grad_residual):
  69. (indices,) = ctx.saved_tensors
  70. assert grad_output.ndim >= 2
  71. other_shape = grad_output.shape[1:]
  72. assert grad_residual.shape[1:] == other_shape
  73. grad_input = grad_residual
  74. # grad_input[indices] += grad_output
  75. indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
  76. indices = indices.expand_as(grad_output)
  77. grad_input.scatter_add_(0, indices, grad_output)
  78. return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
  79. index_first_axis_residual = IndexFirstAxisResidual.apply
  80. def unpad_input(hidden_states, attention_mask, unused_mask=None):
  81. """
  82. Arguments:
  83. hidden_states: (batch, seqlen, ...)
  84. attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
  85. unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
  86. Return:
  87. hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
  88. indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
  89. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
  90. max_seqlen_in_batch: int
  91. seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
  92. """
  93. all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
  94. seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
  95. used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  96. indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
  97. max_seqlen_in_batch = seqlens_in_batch.max().item()
  98. cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  99. # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
  100. # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
  101. # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
  102. # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
  103. # so we write custom forward and backward to make it a bit faster.
  104. return (
  105. index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
  106. indices,
  107. cu_seqlens,
  108. max_seqlen_in_batch,
  109. used_seqlens_in_batch,
  110. )
  111. def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
  112. """
  113. Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
  114. The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
  115. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
  116. ```
  117. [
  118. [2, 3, 0, 0, 0, 0],
  119. [3, 2, 0, 0, 0, 0],
  120. [6, 0, 0, 0, 0, 0]
  121. ]
  122. ```
  123. , which refers to the 3D-attention mask:
  124. ```
  125. [
  126. [
  127. [1, 0, 0, 0, 0, 0],
  128. [1, 1, 0, 0, 0, 0],
  129. [0, 0, 1, 0, 0, 0],
  130. [0, 0, 1, 1, 0, 0],
  131. [0, 0, 1, 1, 1, 0],
  132. [0, 0, 0, 0, 0, 1]
  133. ],
  134. [
  135. [1, 0, 0, 0, 0, 0],
  136. [1, 1, 0, 0, 0, 0],
  137. [1, 1, 1, 0, 0, 0],
  138. [0, 0, 0, 1, 0, 0],
  139. [0, 0, 0, 1, 1, 0],
  140. [0, 0, 0, 0, 0, 1]
  141. ],
  142. [
  143. [1, 0, 0, 0, 0, 0],
  144. [1, 1, 0, 0, 0, 0],
  145. [1, 1, 1, 0, 0, 0],
  146. [1, 1, 1, 1, 0, 0],
  147. [1, 1, 1, 1, 1, 0],
  148. [1, 1, 1, 1, 1, 1]
  149. ]
  150. ]
  151. ```.
  152. Arguments:
  153. hidden_states: (batch, seqlen, ...)
  154. attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
  155. Return:
  156. hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
  157. indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
  158. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
  159. max_seqlen_in_batch: int
  160. """
  161. length = attention_mask_in_length.sum(dim=-1)
  162. seqlen = attention_mask_in_length.size(-1)
  163. attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
  164. real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
  165. seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
  166. indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
  167. max_seqlen_in_batch = seqlens_in_batch.max().item()
  168. cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  169. # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
  170. # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
  171. # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
  172. # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
  173. # so we write custom forward and backward to make it a bit faster.
  174. return (
  175. index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
  176. indices,
  177. cu_seqlens,
  178. max_seqlen_in_batch,
  179. )
  180. def pad_input(hidden_states, indices, batch, seqlen):
  181. """
  182. Arguments:
  183. hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
  184. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
  185. batch: int, batch size for the padded sequence.
  186. seqlen: int, maximum sequence length for the padded sequence.
  187. Return:
  188. hidden_states: (batch, seqlen, ...)
  189. """
  190. dim = hidden_states.shape[-1]
  191. # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
  192. # output[indices] = hidden_states
  193. output = index_put_first_axis(hidden_states, indices, batch * seqlen)
  194. return rearrange(output, "(b s) ... -> b s ...", b=batch)