# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn

import flash_attn_cuda

def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.

    0 means the block is skipped; nonzero means the block is not skipped.

    Argument:
        blockmask: (row, col): a 0-1 tensor
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
            indices of the nonzero blocks, padded with -1 to reach length @row.
            The indices are multiplied by 4, with the smallest bit encoding whether the block is
            the first nonzero in its row, and the 2nd smallest bit encoding whether it is the
            last nonzero in its row.
    """
    assert not causal
    # TD [2022-05-13]: The indexing and sorting is very tricky
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    # For each column, gather the row indices of the nonzero blocks (the stable descending sort
    # keeps them in increasing row order), then recover each row's position after the sort.
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    # For each row, find its first and last nonzero column, and where those entries end up after
    # the column-wise sort above.
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    # Encode: row index * 4, +1 if first nonzero in its row, +2 if last nonzero in its row,
    # and -1 for skipped blocks.
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    nonzero_idx[nonzero_val == 0] = -1
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)
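
# Illustrative sketch: converting a small hand-written 2x3 block mask. The mask values are
# arbitrary; the shape/dtype/encoding notes below follow from convert_blockmask above.
#
#   mask = torch.tensor([[1, 0, 1],
#                        [0, 1, 1]], device="cuda")
#   converted = convert_blockmask(mask, causal=False)
#   # converted: shape (3, 2), i.e. (col, row); dtype torch.int32. Each entry is a row index * 4
#   # with the two low bits flagging first/last nonzero in that row, and -1 for skipped blocks.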

def _flash_blocksparse_attn_forward(
    qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
    )
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


def _flash_blocksparse_attn_backward(
    dout,
    qkv,
    out,
    S_dmask,
    softmax_lse,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale,
    causal,
):
    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
        dout,
        qkv,
        out,
        S_dmask,
        softmax_lse,
        cu_seqlens,
        blockmask,
        dropout_p,
        softmax_scale,
        max_s,
        causal,
        None,
    )
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv

class FlashBlocksparseAttnFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=False,
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None here; pass another tensor (context) as a placeholder just to get it running
        dqkv = _flash_blocksparse_attn_backward(
            dout,
            qkv,
            context,
            context,
            softmax_lse,
            cu_seqlens,
            blockmask,
            ctx.dropout_p,
            ctx.max_s,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None, None

# We duplicate code to return both the output and the softmax for testing.
# Returning both makes the backward pass a bit slower, so we keep the other version around for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=True,
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _flash_blocksparse_attn_backward(
            dout,
            qkv,
            context,
            S_dmask,
            softmax_lse,
            cu_seqlens,
            blockmask,
            ctx.dropout_p,
            ctx.max_s,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None

def flash_blocksparse_attn_func(
    qkv,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
    convert_mask=True,
):
    """dropout_p should be set to 0.0 during evaluation.

    If softmax_scale is None, it defaults to qkv.shape[-1] ** (-0.5).
    If return_attn_probs is True, the softmax results (S_dmask, softmax_lse) are returned as
    well, at the cost of a slightly slower backward pass.
    If convert_mask is True, the 0-1 blockmask is first converted with convert_blockmask.
    """
    func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
    if convert_mask:
        blockmask = convert_blockmask(blockmask, causal=causal)
    return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
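
# Usage sketch (illustrative; the shapes and placeholder sizes below are assumptions based on
# the packed/varlen convention implied by cu_seqlens and max_s, not guarantees of this module,
# and the block granularity of blockmask depends on the compiled CUDA kernel):
#
#   # qkv packed over all tokens of the batch: (total_tokens, 3, nheads, head_dim), fp16 on GPU
#   qkv = torch.randn(256, 3, 12, 64, device="cuda", dtype=torch.float16)
#   # cumulative sequence lengths, int32, length batch_size + 1, starting at 0
#   cu_seqlens = torch.tensor([0, 128, 256], device="cuda", dtype=torch.int32)
#   # 0-1 block mask over (row blocks, col blocks)
#   blockmask = torch.ones(8, 16, device="cuda", dtype=torch.uint8)
#   out = flash_blocksparse_attn_func(
#       qkv, cu_seqlens, blockmask, dropout_p=0.0, max_s=128, softmax_scale=None, causal=False
#   )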