# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import flash_attn_cuda
import torch
import torch.nn as nn


def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.
    0 means the block is skipped.
    nonzero means the block is not skipped.
    Argument:
        blockmask: (row, col): a 0-1 tensor
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
        indices of the nonzero blocks, padded with -1 to reach length @row.
        The indices are multiplied by 4, with the smallest bit used to encode whether
        it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
        the last nonzero in its row.
    """
    assert not causal
    # TD [2022-05-13]: The indexing and sorting is very tricky
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    nonzero_idx[nonzero_val == 0] = -1
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)
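

# Illustrative, hand-worked example of the encoding above (not part of the original file):
# for a hypothetical 2x2 blockmask
#     [[1, 0],
#      [1, 1]]
# convert_blockmask(blockmask, causal=False) returns
#     [[3, 5],
#      [6, -1]]
# Column 0 keeps rows 0 and 1: row 0's block is both the first and last nonzero in its row
# (0 * 4 + 1 + 2 = 3), and row 1's block is the first nonzero in its row (1 * 4 + 1 = 5).
# Column 1 keeps only row 1, whose block is the last nonzero in its row (1 * 4 + 2 = 6),
# and the column is padded with -1 up to length nrow.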


def _flash_blocksparse_attn_forward(
    qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
    )
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


def _flash_blocksparse_attn_backward(
    dout,
    qkv,
    out,
    S_dmask,
    softmax_lse,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale,
    causal,
):
    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
        dout,
        qkv,
        out,
        S_dmask,
        softmax_lse,
        cu_seqlens,
        blockmask,
        dropout_p,
        softmax_scale,
        max_s,
        causal,
        None,
    )
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv


class FlashBlocksparseAttnFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=False,
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None, temporarily use another tensor just to get it running
        dqkv = _flash_blocksparse_attn_backward(
            dout,
            qkv,
            context,
            context,
            softmax_lse,
            cu_seqlens,
            blockmask,
            ctx.dropout_p,
            ctx.max_s,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None, None


# We duplicate code to return both the output and the softmax for testing.
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=True,
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _flash_blocksparse_attn_backward(
            dout,
            qkv,
            context,
            S_dmask,
            softmax_lse,
            cu_seqlens,
            blockmask,
            ctx.dropout_p,
            ctx.max_s,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


def flash_blocksparse_attn_func(
    qkv,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
    convert_mask=True,
):
    """dropout_p should be set to 0.0 during evaluation"""
    func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
    if convert_mask:
        blockmask = convert_blockmask(blockmask, causal=causal)
    return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
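

# Usage sketch (illustrative, not from the original file). It assumes fp16 CUDA tensors in the
# packed variable-length layout that flash_attn_cuda expects: qkv of shape
# (total_tokens, 3, nheads, headdim), cu_seqlens of dtype torch.int32 holding cumulative
# sequence lengths, and a 0-1 blockmask whose (row, col) granularity matches the block sizes
# compiled into the CUDA kernel (nrow_blocks / ncol_blocks below are placeholders):
#
#     batch_size, seqlen, nheads, headdim = 2, 512, 8, 64
#     qkv = torch.randn(batch_size * seqlen, 3, nheads, headdim,
#                       device="cuda", dtype=torch.float16, requires_grad=True)
#     cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen,
#                               device="cuda", dtype=torch.int32)
#     blockmask = torch.ones(nrow_blocks, ncol_blocks, device="cuda")
#     out = flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p=0.1,
#                                       max_s=seqlen, causal=False)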