123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370 |
- # Copyright (c) 2023, Tri Dao.
- from typing import Optional, Union
- import torch
- import torch.nn as nn
- # isort: off
- # We need to import the CUDA kernels after importing torch
- import flashattn_hopper_cuda
- # isort: on
def maybe_contiguous(x):
    """Return ``x.contiguous()`` only when the last dimension is not unit-stride.

    ``None`` is passed through unchanged, as is any tensor whose
    last-dimension stride is already 1 (no copy is made in that case).
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
def _flash_attn_forward(q, k, v, softmax_scale, causal):
    """Run ``flashattn_hopper_cuda.fwd`` and hand back its seven results.

    The inputs are forwarded to the extension untouched (no contiguity fix-up
    is applied here). The fourth positional kernel argument is always ``None``
    (presumably an optional pre-allocated output -- TODO confirm against the
    extension's signature).

    Returns:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask)`` in the order the
        kernel produced them. The ``q``/``k``/``v`` entries are whatever the
        kernel returned in those slots, not necessarily the objects passed in.
    """
    results = flashattn_hopper_cuda.fwd(q, k, v, None, softmax_scale, causal)
    # Unpack explicitly so an unexpected number of kernel outputs fails loudly here.
    out, q_ret, k_ret, v_ret, out_padded, softmax_lse, S_dmask = results
    return out, q_ret, k_ret, v_ret, out_padded, softmax_lse, S_dmask
def _flash_attn_backward(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, softmax_scale, causal
):
    """Run ``flashattn_hopper_cuda.bwd`` and return ``(dq, dk, dv, softmax_d)``.

    ``dq``, ``dk`` and ``dv`` are gradient buffers supplied by the caller; they
    are allocated by this module, so they are expected to be contiguous already.
    No contiguity fix-up is applied to ``dout``, ``q``, ``k``, ``v`` or ``out``.

    Returns:
        The four values the kernel returns, in kernel order.
    """
    results = flashattn_hopper_cuda.bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, softmax_scale, causal
    )
    # Unpack explicitly so an unexpected number of kernel outputs fails loudly here.
    dq_ret, dk_ret, dv_ret, softmax_d = results
    return dq_ret, dk_ret, dv_ret, softmax_d
def _flash_attn_varlen_forward(
    q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal
):
    """Run ``flashattn_hopper_cuda.varlen_fwd`` on packed variable-length inputs.

    ``q``, ``k`` and ``v`` are made contiguous first whenever their
    last-dimension stride is not 1. The two ``None`` positional arguments
    passed to the kernel (4th and 7th) are fixed here; their meaning is defined
    by the extension (TODO confirm against its signature).

    Returns:
        ``(out, q, k, v, out_padded, softmax_lse)`` in the order the kernel
        produced them.
    """
    # Copy only the tensors whose innermost dimension is strided.
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))
    results = flashattn_hopper_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
    )
    out, q_ret, k_ret, v_ret, out_padded, softmax_lse = results
    return out, q_ret, k_ret, v_ret, out_padded, softmax_lse
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    causal,
):
    """Run the variable-length backward kernel and return ``(dq, dk, dv, softmax_d)``.

    ``dout``, ``q``, ``k``, ``v`` and ``out`` are made contiguous first whenever
    their last-dimension stride is not 1. ``dq``, ``dk`` and ``dv`` are gradient
    buffers supplied by the caller; they are allocated by this module, so they
    are expected to be contiguous already.

    The kernel is reached through the module-level ``flashattn_hopper_cuda``
    import, like every other wrapper in this file. The previous code called an
    undefined ``_get_fa_module()`` helper and raised ``NameError`` on first use.

    NOTE(review): ``FlashAttnVarlenFunc.backward`` does not call this function
    yet (its TODO says var-seq-len is not supported in the bwd kernel), so
    confirm that the extension actually exposes ``varlen_bwd`` before enabling it.

    Returns:
        The four values the kernel returns, in kernel order.
    """
    # Reuse the shared helper rather than redefining it as a local lambda.
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d = flashattn_hopper_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
    )
    # Debug aid:
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d
class FlashAttnFunc(torch.autograd.Function):
    """Autograd function pairing ``_flash_attn_forward`` with ``_flash_attn_backward``."""

    @staticmethod
    def forward(ctx, q, k, v, softmax_scale, causal):
        """Run the forward kernel and stash what ``backward`` needs on ``ctx``.

        Returns:
            ``(out, softmax_lse)`` as produced by ``_flash_attn_forward``.
        """
        # Fall back to headdim ** -0.5 when the caller gave no explicit scale.
        scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse, _ = _flash_attn_forward(
            q, k, v, scale, causal
        )
        ctx.softmax_scale = scale
        ctx.causal = causal
        # Save the tensors returned by the kernel (not the original inputs).
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients for ``q``, ``k`` and ``v``.

        Extra incoming gradients (``*args``, e.g. for ``softmax_lse``) are ignored.
        ``softmax_scale`` and ``causal`` receive no gradient.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        # The return value is discarded; dq/dk/dv are used as the result buffers.
        _flash_attn_backward(
            dout, q, k, v, out, softmax_lse, dq, dk, dv, ctx.softmax_scale, ctx.causal
        )
        # We could have padded the head dimension, so trim back to dout's width.
        headdim = dout.shape[-1]
        dq, dk, dv = (grad[..., :headdim] for grad in (dq, dk, dv))
        return dq, dk, dv, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for the variable-length (packed) forward kernel.

    Only the forward pass is wired up. ``backward`` is a placeholder that
    returns ``None`` for every input, so no gradient reaches ``q``, ``k`` or ``v``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
    ):
        """Run the varlen forward kernel and stash state on ``ctx``.

        Returns:
            ``(out, softmax_lse)`` as produced by ``_flash_attn_varlen_forward``.
        """
        # Fall back to headdim ** -0.5 when the caller gave no explicit scale.
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal=causal,
        )
        # Save the tensors returned by the kernel (not the original inputs).
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
        )
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        """Placeholder backward: returns ``None`` for each of the 9 forward inputs."""
        # TODO: Uncomment these when var-seq-len is supported in bwd kernel.
        # q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        # dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # _flash_attn_varlen_backward(
        #     dout,
        #     q,
        #     k,
        #     v,
        #     out,
        #     softmax_lse,
        #     dq,
        #     dk,
        #     dv,
        #     cu_seqlens_q,
        #     cu_seqlens_k,
        #     ctx.max_seqlen_q,
        #     ctx.max_seqlen_k,
        #     ctx.softmax_scale,
        #     ctx.causal,
        # )
        # dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        # dk = dk[..., : dout.shape[-1]]
        # dv = dv[..., : dout.shape[-1]]
        # return dq, dk, dv, None, None, None, None, None, None
        # One slot per forward input: q, k, v, cu_seqlens_q, cu_seqlens_k,
        # max_seqlen_q, max_seqlen_k, softmax_scale, causal.
        return (None,) * 9
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
):
    """Compute attention via ``FlashAttnFunc`` (backed by ``flashattn_hopper_cuda``).

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    Return:
        A tuple ``(out, softmax_lse)``; both values are always returned.
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse: (batch_size, nheads, seqlen). The logsumexp of each row of the
            matrix QK^T * scaling (e.g., log of the softmax normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
    )
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale=None,
    causal=False,
):
    """Compute attention over packed variable-length sequences via ``FlashAttnVarlenFunc``.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    NOTE: the backward pass is not implemented for this variant yet;
    ``FlashAttnVarlenFunc.backward`` returns ``None`` for every input, so no
    gradient is propagated to q, k or v.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    Return:
        A tuple ``(out, softmax_lse)``; both values are always returned.
        out: (total, nheads, headdim).
        softmax_lse: (nheads, total_q_seqlen). The logsumexp of each row of the
            matrix QK^T * scaling (e.g., log of the softmax normalization factor).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
    )
|