# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch
import torch.nn as nn

# isort: off
# We need to import the CUDA kernels after importing torch
import flashattn_hopper_cuda

# isort: on


def maybe_contiguous(x):
    """Return ``x`` with a unit-stride last dimension; ``None`` passes through.

    ``x.contiguous()`` is called only when ``x.stride(-1) != 1``; otherwise the
    very same object is returned.
    """
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _flash_attn_forward(
    q,
    k,
    v,
    softmax_scale,
    causal,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    sink_token_length=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
):
    """Call ``flashattn_hopper_cuda.fwd`` after fixing up input strides.

    Returns the 6-tuple produced by the kernel:
    ``(out, q, k, v, out_padded, softmax_lse)``.
    """
    q, k = [maybe_contiguous(x) for x in (q, k)]
    # v is copied only when neither its last nor its third-from-last dimension
    # has unit stride.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.fwd(
        q,
        k,
        v,
        None,  # presumably the optional pre-allocated output -- confirm against the binding
        softmax_scale,
        causal,
        q_descale,
        k_descale,
        v_descale,
        window_size[0],
        window_size[1],
        sink_token_length,
        softcap,
        num_splits,
        pack_gqa,
    )
    return out, q, k, v, out_padded, softmax_lse


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale,
    causal,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
):
    """Call ``flashattn_hopper_cuda.fwd_varlen`` after fixing up input strides.

    Returns the 6-tuple produced by the kernel:
    ``(out, q, k, v, out_padded, softmax_lse)``.
    """
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.fwd_varlen(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        None,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        q_descale,
        k_descale,
        v_descale,
        window_size[0],
        window_size[1],
        softcap,
        num_splits,
        pack_gqa,
    )
    return out, q, k, v, out_padded, softmax_lse


def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    softmax_scale,
    causal,
    window_size=(-1, -1),
    sink_token_length=0,
    softcap=0.0,
    deterministic=False,
):
    """Call ``flashattn_hopper_cuda.bwd`` and return ``(dq, dk, dv, softmax_d)``.

    ``dq``, ``dk`` and ``dv`` are buffers supplied by the caller and handed to
    the kernel; any additional kernel outputs are discarded.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, *_ = flashattn_hopper_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        sink_token_length,
        softcap,
        deterministic,
    )
    return dq, dk, dv, softmax_d


def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dq,
    dk,
    dv,
    softmax_scale,
    causal,
    window_size=(-1, -1),
    softcap=0.0,
    deterministic=False,
):
    """Call ``flashattn_hopper_cuda.bwd_varlen`` and return ``(dq, dk, dv, softmax_d)``.

    ``dq``, ``dk`` and ``dv`` are buffers supplied by the caller and handed to
    the kernel; any additional kernel outputs are discarded.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, *_ = flashattn_hopper_cuda.bwd_varlen(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        None,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        deterministic,
    )
    return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a single packed QKV tensor."""

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=(-1, -1),
        sink_token_length=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # Q, K and V are stacked along dim -3.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # Q, K and V heads are concatenated along the head dimension:
            # num_heads_q query heads, then equal numbers of K and V heads.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
            q,
            k,
            v,
            softmax_scale,
            causal=causal,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            sink_token_length=sink_token_length,
            softcap=softcap,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.sink_token_length = sink_token_length
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remembered so that backward can rebuild a gradient with the same packing.
        ctx.ndim = qkv.dim()
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        # Allocate one packed gradient buffer and hand views of it to the
        # kernel, so no concatenation of dq/dk/dv is needed afterwards.
        if ctx.ndim == 5:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.unbind(dim=-3)
        else:
            num_heads_q = q.shape[2]
            num_heads_k = k.shape[2]
            qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        # The return value is discarded; the gradients are read from the
        # dq/dk/dv views passed in.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.sink_token_length,
            ctx.softcap,
            ctx.deterministic,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient per forward input (11): only qkv is differentiable.
        return dqkv, None, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention on separate Q, K and V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=(-1, -1),
        sink_token_length=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
            q,
            k,
            v,
            softmax_scale,
            causal=causal,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            sink_token_length=sink_token_length,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.sink_token_length = sink_token_length
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # The return value is discarded; the gradients are read from the
        # dq/dk/dv buffers passed in.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.sink_token_length,
            ctx.softcap,
            ctx.deterministic,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient per forward input (14): q, k, v plus 11 non-differentiable arguments.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for the variable-length (``cu_seqlens``) attention kernels."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=(-1, -1),
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal=causal,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # The return value is discarded; the gradients are read from the
        # dq/dk/dv buffers passed in.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.deterministic,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # One gradient per forward input (17): q, k, v plus 14 non-differentiable arguments.
        return (
            dq, dk, dv,
            None, None, None, None, None, None, None,
            None, None, None, None, None, None, None,
        )


def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    sink_token_length=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
):
    """Attention on a packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), either pass a 4-D qkv together
    with num_heads_q (see below) or use flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or a 4-D tensor whose dim 2 holds
            num_heads_q query heads followed by equal numbers of key heads and value heads.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: forwarded unchanged to the CUDA kernel
            (presumably descaling factors for low-precision inputs -- confirm against the kernel).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        sink_token_length: int. Forwarded unchanged to the CUDA kernel.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Required when qkv is 4-D; number of query heads packed in dim 2.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale,
        k_descale,
        v_descale,
        window_size,
        sink_token_length,
        softcap,
        deterministic,
        num_heads_q,
    )


def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    sink_token_length=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
):
    """Attention on separate Q, K and V tensors.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: forwarded unchanged to the CUDA kernel
            (presumably descaling factors for low-precision inputs -- confirm against the kernel).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        sink_token_length: int. Forwarded unchanged to the CUDA kernel.
        softcap: float. Forwarded unchanged to the CUDA kernel.
        num_splits: int. Forwarded unchanged to the CUDA kernel.
        pack_gqa: forwarded unchanged to the CUDA kernel.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix
            QK^T * scaling (e.g., log of the softmax normalization factor). Always returned.
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        q_descale,
        k_descale,
        v_descale,
        window_size,
        sink_token_length,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    softmax_scale=None,
    causal=False,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
):
    """Variable-length attention; thin wrapper around ``FlashAttnVarlenFunc.apply``.

    softmax_scale defaults to ``q.shape[-1] ** -0.5`` when None. All other arguments are
    forwarded unchanged to the kernel wrappers.
    NOTE(review): cu_seqlens_q / cu_seqlens_k are presumably cumulative sequence-length
    tensors and max_seqlen_q / max_seqlen_k the per-batch maxima -- confirm against the kernel.

    Return:
        (out, softmax_lse) as produced by ``flashattn_hopper_cuda.fwd_varlen``.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        q_descale,
        k_descale,
        v_descale,
        window_size,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
    )


def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Forward all arguments to ``flashattn_hopper_cuda.fwd_combine`` and return its result.

    NOTE(review): presumably merges partial attention outputs using their logsumexp
    values -- confirm against the kernel.
    """
    return flashattn_hopper_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    sink_token_length=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table
            (i.e. paged KV cache). page_block_size must be a multiple of 256.
            Must have a contiguous last dimension.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table
            (i.e. paged KV cache). Must have a contiguous last dimension.
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is expanded to an int32 tensor with one entry per row of k_cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the
            KV cache. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the
            cache might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts.
            If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cu_seqlens_q [optional]: forwarded to the CUDA kernel
            (presumably cumulative query sequence lengths -- confirm against the kernel).
        max_seqlen_q [optional]: int. Forwarded unchanged to the CUDA kernel.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        sink_token_length: int. Must be 0 (asserted).
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a
            heuristic to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        pack_gqa: forwarded unchanged to the CUDA kernel. Can be tuned for speed.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor). When requested, the result is the tuple
            (out, softmax_lse, *rest), where rest holds any further outputs of the kernel.
    """
    assert sink_token_length == 0
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # Broadcast a scalar length to one int32 entry per row of k_cache.
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    page_table = maybe_contiguous(page_table)
    cu_seqlens_q = maybe_contiguous(cu_seqlens_q)
    out, softmax_lse, *rest = flashattn_hopper_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        None,  # out
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        page_table,
        cu_seqlens_q,
        max_seqlen_q,
        softmax_scale,
        causal,
        None, None, None,  # qkv_descale
        window_size[0],
        window_size[1],
        sink_token_length,
        softcap,
        rotary_interleaved,
        num_splits,
        pack_gqa,
    )
    # Any extra kernel outputs are passed through after softmax_lse.
    return (out, softmax_lse, *rest) if return_softmax_lse else out