# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flashattn_hopper_cuda

# isort: on


def maybe_contiguous(x):
    """Return ``x`` with a unit stride in its last dimension.

    ``None`` is passed through unchanged, and tensors whose last dimension is
    already unit-stride are returned as-is (no copy).
    """
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _flash_attn_forward(
    q, k, v, softmax_scale, causal, descale_q=None, descale_k=None, descale_v=None
):
    """Run the fixed-length forward kernel ``flashattn_hopper_cuda.fwd``.

    The ``descale_*`` arguments are forwarded to the kernel unchanged
    (presumably optional per-tensor scales for low-precision inputs -- TODO
    confirm against the extension binding).

    Returns:
        The 7-tuple produced by the kernel:
        ``(out, q, k, v, out_padded, softmax_lse, S_dmask)``. The returned
        q/k/v may differ from the inputs (the backward pass notes the head
        dimension may have been padded), which is why callers keep the
        returned tensors rather than the originals.
    """
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
        q,
        k,
        v,
        None,  # presumably an optional preallocated output buffer; none supplied
        softmax_scale,
        descale_q,
        descale_k,
        descale_v,
        causal,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask


def _flash_attn_backward(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, softmax_scale, causal, deterministic=False
):
    """Run the fixed-length backward kernel ``flashattn_hopper_cuda.bwd``.

    ``dq``, ``dk`` and ``dv`` are caller-allocated buffers handed to the
    kernel; the gradients are taken from the kernel's return value.

    Returns:
        ``(dq, dk, dv, softmax_d)`` -- the first four values returned by the
        kernel; any additional return values are discarded.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, *_rest = flashattn_hopper_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        softmax_scale,
        causal,
        deterministic,
    )
    return dq, dk, dv, softmax_d


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    causal,
    seqused_q=None,
    seqused_k=None,
):
    """Run the variable-length forward kernel ``flashattn_hopper_cuda.varlen_fwd``.

    Returns:
        The 6-tuple produced by the kernel:
        ``(out, q, k, v, out_padded, softmax_lse)``.
    """
    # Reuse the shared module-level helper instead of a local lambda so the
    # contiguity rule (and None handling) is identical across all entry points.
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
        q,
        k,
        v,
        None,  # presumably an optional preallocated output buffer; none supplied
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
    )
    return out, q, k, v, out_padded, softmax_lse


def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    causal,
    deterministic=False,
    seqused_q=None,
    seqused_k=None,
):
    """Run the variable-length backward kernel ``flashattn_hopper_cuda.varlen_bwd``.

    Returns:
        ``(dq, dk, dv, softmax_d)`` -- the first four values returned by the
        kernel; any additional return values are discarded.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, *_rest = flashattn_hopper_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        deterministic,
    )
    return dq, dk, dv, softmax_d


class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper around the fixed-length FlashAttention kernels.

    ``forward`` returns ``(out, softmax_lse)``; only ``out`` propagates a
    gradient -- the incoming gradient for ``softmax_lse`` is ignored.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        deterministic=False,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
            q, k, v, softmax_scale, causal
        )
        # Save the kernel-returned q/k/v and out_padded: backward must see the
        # same (possibly padded) tensors the forward kernel used.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        # *args holds the gradient w.r.t. softmax_lse, which is not propagated.
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.deterministic,
        )
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]  # We could have padded the head dimension
        dk = dk[..., :head_dim]
        dv = dv[..., :head_dim]
        # One gradient per forward input: q, k, v, softmax_scale, causal, deterministic.
        return dq, dk, dv, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper around the variable-length (packed) FlashAttention kernels.

    ``forward`` returns ``(out, softmax_lse)``; only ``out`` propagates a
    gradient -- the incoming gradient for ``softmax_lse`` is ignored.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        deterministic=False,
        seqused_q=None,
        seqused_k=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal=causal,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
        )
        # seqused_q / seqused_k may be None; save_for_backward accepts None entries.
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k
        )
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.deterministic = deterministic
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        # *args holds the gradient w.r.t. softmax_lse, which is not propagated.
        (
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
        ) = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.deterministic,
            seqused_q,
            seqused_k,
        )
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]  # We could have padded the head dimension
        dk = dk[..., :head_dim]
        dv = dv[..., :head_dim]
        # One gradient per forward input (12 in total); only q, k, v are differentiable.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None


def flash_attn_func(
    q, k, v, softmax_scale=None, causal=False, deterministic=False
):
    """Compute attention over fixed-length batches with the FlashAttention Hopper kernels.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V
    with fewer heads than Q. Note that the number of heads in Q must be divisible
    by the number of heads in KV. For example, if Q has 6 heads and K, V have 2
    heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of
    Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the
    attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
    mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for
            auto-regressive modeling).
        deterministic: bool. Whether to use the deterministic implementation of
            the backward pass, which is slightly slower and uses more memory.
            The forward pass is always deterministic.
    Return:
        A tuple ``(out, softmax_lse)``:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse: (batch_size, nheads, seqlen). The logsumexp of each row of
            the matrix QK^T * scaling (e.g., log of the softmax normalization
            factor). Any gradient flowing into it is ignored by the backward pass.
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        deterministic,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale=None,
    causal=False,
    deterministic=False,
    seqused_q=None,
    seqused_k=None,
):
    """Compute attention over packed variable-length sequences.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V
    with fewer heads than Q. Note that the number of heads in Q must be divisible
    by the number of heads in KV. For example, if Q has 6 heads and K, V have 2
    heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of
    Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the
    attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
    mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for
            auto-regressive modeling).
        deterministic: bool. Whether to use the deterministic implementation of
            the backward pass.
        seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual
            number of query and output tokens in each sequence.
        seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual
            number of key and value tokens in each sequence.
    Return:
        A tuple ``(out, softmax_lse)``:
        out: (total, nheads, headdim).
        softmax_lse: (nheads, total_q_seqlen). The logsumexp of each row of the
            matrix QK^T * scaling (e.g., log of the softmax normalization factor).
            Any gradient flowing into it is ignored by the backward pass.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        deterministic,
        seqused_q,
        seqused_k,
    )