from typing import Optional, Union

import torch

import aphrodite._custom_ops as ops


def maybe_contiguous(x):
    # Only copy when the last dimension is not already contiguous.
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal,
    window_size, softcap, alibi_slopes,
    return_softmax, *, out=None
):
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    (out, q, k, v, out_padded, softmax_lse,
     S_dmask, rng_state) = ops.fwd(
        q=q,
        k=k,
        v=v,
        out=out,
        alibi_slopes=alibi_slopes,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        softcap=softcap,
        return_softmax=return_softmax,
        gen=None,
    )  # type: ignore
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None
):
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    (out, q, k, v, out_padded, softmax_lse,
     S_dmask, rng_state) = ops.varlen_fwd(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        block_table=block_table,
        return_softmax=return_softmax,
        gen=None,
        out=out,
        seqused_k=None,
        zero_tensors=False,
    )  # type: ignore
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        (out, q, k, v, out_padded, softmax_lse,
         S_dmask, rng_state) = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
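

# Illustrative usage sketch, not part of the wrapper API above: driving
# FlashAttnFunc directly through autograd's ``apply``. The shapes, fp16
# dtype, CUDA device and the helper name are assumptions for illustration
# only; running it requires the aphrodite custom ops to be built.
def _example_fixed_len_attention():
    batch, seqlen, nheads, headdim = 2, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, headdim,
                    dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # Positional arguments mirror FlashAttnFunc.forward (after ctx).
    return FlashAttnFunc.apply(
        q, k, v,
        0.0,        # dropout_p
        None,       # softmax_scale -> defaults to headdim ** -0.5 in forward
        True,       # causal
        (-1, -1),   # window_size: no sliding-window limit
        0.0,        # softcap deactivated
        None,       # alibi_slopes
        False,      # deterministic
        False,      # return_softmax
    )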


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        (out, q, k, v, out_padded, softmax_lse,
         S_dmask, rng_state) = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q,
            cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )
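

# Illustrative usage sketch, not part of the wrapper API above: two
# variable-length sequences packed into a single (total_tokens, nheads,
# headdim) tensor and attended with flash_attn_varlen_func. The shapes,
# fp16 dtype, CUDA device and the helper name are assumptions for
# illustration; cu_seqlens_* are int32 prefix sums of the per-sequence
# lengths, starting at 0.
def _example_varlen_attention():
    nheads, headdim = 8, 64
    seqlens = [5, 11]                       # lengths of the two sequences
    total_tokens = sum(seqlens)
    q = torch.randn(total_tokens, nheads, headdim,
                    dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens),
        max_seqlen_k=max(seqlens),
        causal=True,
    )
    return out  # packed like q: (total_tokens, nheads, headdim)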


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
):
    assert k_cache.stride(-1) == 1, (
        "k_cache must have contiguous last dimension"
    )
    assert v_cache.stride(-1) == 1, (
        "v_cache must have contiguous last dimension"
    )
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens,
            dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = ops.fwd_kvcache(
        q=q,
        kcache=k_cache,
        vcache=v_cache,
        k=k,
        v=v,
        seqlens_k=cache_seqlens,
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        cache_batch_idx=cache_batch_idx,
        block_table=block_table,
        alibi_slopes=alibi_slopes,
        out=out,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        num_splits=num_splits,
    )  # type: ignore
    return (out, softmax_lse) if return_softmax_lse else out
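

# Illustrative usage sketch, not part of the wrapper API above: one decode
# step against a pre-allocated contiguous KV cache. The cache layout
# (batch, max_seqlen, nheads_k, headdim), fp16 dtype, CUDA device and the
# helper name are assumptions for illustration; cache_seqlens gives the
# number of cache slots already filled per sequence, and the new k/v tokens
# passed alongside q are expected to be appended by the kernel before
# attention is computed, mirroring the upstream flash-attn kvcache path.
def _example_kvcache_decode():
    batch, max_seqlen, nheads, headdim = 2, 256, 8, 64
    k_cache = torch.zeros(batch, max_seqlen, nheads, headdim,
                          dtype=torch.float16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    q = torch.randn(batch, 1, nheads, headdim,
                    dtype=torch.float16, device="cuda")  # one new query token
    k_new = torch.randn_like(q)
    v_new = torch.randn_like(q)
    cache_seqlens = torch.tensor([17, 42], dtype=torch.int32, device="cuda")
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        k=k_new, v=v_new,
        cache_seqlens=cache_seqlens,
        causal=True,
    )
    return out  # (batch, 1, nheads, headdim)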