12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569 |
- # Copyright (c) 2023, Tri Dao.
- from typing import Optional, Sequence, Tuple, Union
- import torch
- import torch.nn as nn
- # isort: off
- # We need to import the CUDA kernels after importing torch
- import flash_attn_2_cuda as flash_attn_cuda
- # isort: on
def maybe_contiguous(x):
    """Return ``x`` with a unit stride along its last dimension.

    ``None`` and tensors whose last dimension already has stride 1 are returned
    unchanged (the very same object); anything else is replaced by a contiguous copy.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
- def _get_block_size_n(device, head_dim, is_dropout, is_causal):
- # This should match the block sizes in the CUDA kernel
- assert head_dim <= 256
- major, minor = torch.cuda.get_device_capability(device)
- is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
- is_sm80 = major == 8 and minor == 0
- is_sm90 = major == 9 and minor == 0
- if head_dim <= 32:
- return 128
- if head_dim <= 64:
- return 128 if not is_dropout else 64
- elif head_dim <= 96:
- return 64
- elif head_dim <= 128:
- if is_sm8x:
- return 64 if (not is_dropout and is_causal) else 32
- else:
- return 64 if not is_dropout else 32
- elif head_dim <= 160:
- if is_sm8x:
- return 64
- else:
- return 32
- elif head_dim <= 192:
- return 64
- elif head_dim <= 224:
- return 64
- elif head_dim <= 256:
- return 64
def round_multiple(x, m):
    """Round ``x`` up to the next multiple of ``m`` (``x`` itself if already a multiple)."""
    num_blocks = (x + m - 1) // m
    return num_blocks * m
- # torch.compile() support is only enabled for pytorch >= 2.4
- # The reason for this is that we are using the new custom_op and register_fake
- # APIs, which support inplace modification of inputs in the function itself
- if torch.__version__ >= "2.4.0":
- _torch_custom_op_wrapper = torch.library.custom_op
- _torch_register_fake_wrapper = torch.library.register_fake
- else:
- def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
- def wrap(func):
- return func
- if fn is None:
- return wrap
- return fn
- def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
- def wrap(func):
- return func
- if fn is None:
- return wrap
- return fn
- _torch_custom_op_wrapper = noop_custom_op_wrapper
- _torch_register_fake_wrapper = noop_register_fake_wrapper
@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the fixed-length FlashAttention forward kernel (``flash_attn_cuda.fwd``).

    q, k and v whose last dimension is not unit-stride are made contiguous first.

    Returns:
        The 4-tuple ``(out, softmax_lse, S_dmask, rng_state)`` produced by the kernel.
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    kernel_outputs = flash_attn_cuda.fwd(
        q, k, v,
        None,
        alibi_slopes,
        dropout_p, softmax_scale, causal,
        window_size_left, window_size_right,
        softcap,
        return_softmax,
        None,
    )
    out, softmax_lse, S_dmask, rng_state = kernel_outputs
    return out, softmax_lse, S_dmask, rng_state
@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape/dtype-only implementation of ``flash_attn::_flash_attn_forward`` for tracing.

    q is unpacked as (batch_size, seqlen_q, num_heads, head_size) and k's second
    dimension is taken as seqlen_k. Returns uninitialized tensors:
    out like q, softmax_lse (batch_size, num_heads, seqlen_q) in float32, the
    probability tensor (both sequence lengths rounded up to multiples of 128, or a
    0-element tensor when return_softmax is False), and a (2,) int64 rng_state.
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    batch_size, seqlen_q, num_heads, _ = q.shape
    seqlen_k = k.shape[1]
    like_q = dict(dtype=q.dtype, device=q.device, layout=q.layout)

    out = torch.empty_like(q)
    softmax_lse = torch.empty(
        (batch_size, num_heads, seqlen_q),
        dtype=torch.float32,
        device=q.device,
        layout=q.layout,
    )
    if return_softmax:
        p_shape = (
            batch_size,
            num_heads,
            round_multiple(seqlen_q, 128),
            round_multiple(seqlen_k, 128),
        )
    else:
        p_shape = (0,)
    p = torch.empty(p_shape, **like_q)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
    return out, softmax_lse, p, rng_state
# On PyTorch >= 2.4 dispatch through the registered custom op; otherwise call the
# plain Python function directly.
_wrapped_flash_attn_forward = (
    torch.ops.flash_attn._flash_attn_forward
    if torch.__version__ >= "2.4.0"
    else _flash_attn_forward
)
@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the variable-length FlashAttention forward kernel (``flash_attn_cuda.varlen_fwd``).

    Sequences are described by the cumulative-length tensors ``cu_seqlens_q`` /
    ``cu_seqlens_k``; ``block_table``, ``leftpad_k`` and ``seqused_k`` are forwarded
    to the kernel unchanged. q, k and v whose last dimension is not unit-stride are
    made contiguous first.

    Returns:
        The 4-tuple ``(out, softmax_lse, S_dmask, rng_state)`` produced by the kernel.
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    kernel_outputs = flash_attn_cuda.varlen_fwd(
        q, k, v,
        None,
        cu_seqlens_q, cu_seqlens_k,
        seqused_k, leftpad_k, block_table,
        alibi_slopes,
        max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale,
        # NOTE(review): positional flag this wrapper always disables (believed to be
        # the kernel's zero_tensors option) - confirm against flash_attn_2_cuda.
        False,
        causal,
        window_size_left, window_size_right,
        softcap,
        return_softmax,
        None,
    )
    out, softmax_lse, S_dmask, rng_state = kernel_outputs
    return out, softmax_lse, S_dmask, rng_state
@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
def _flash_attn_varlen_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape/dtype-only implementation of ``flash_attn::_flash_attn_varlen_forward``.

    q is unpacked as (total_q, num_heads, head_size); the batch size is one less
    than the number of entries in ``cu_seqlens_q``. Returns uninitialized tensors:
    out like q, softmax_lse (num_heads, total_q) in float32, the probability tensor
    (max sequence lengths rounded up to multiples of 128, or a 0-element tensor
    when return_softmax is False), and a (2,) int64 rng_state.
    """
    q, k, v = map(maybe_contiguous, (q, k, v))
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape
    like_q = dict(dtype=q.dtype, device=q.device, layout=q.layout)

    out = torch.empty_like(q)
    softmax_lse = torch.empty(
        (num_heads, total_q),
        dtype=torch.float32,
        device=q.device,
        layout=q.layout,
    )
    if return_softmax:
        p_shape = (
            batch_size,
            num_heads,
            round_multiple(max_seqlen_q, 128),
            round_multiple(max_seqlen_k, 128),
        )
    else:
        p_shape = (0,)
    p = torch.empty(p_shape, **like_q)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
    return out, softmax_lse, p, rng_state
# On PyTorch >= 2.4 dispatch through the registered custom op; otherwise call the
# plain Python function directly.
_wrapped_flash_attn_varlen_forward = (
    torch.ops.flash_attn._flash_attn_varlen_forward
    if torch.__version__ >= "2.4.0"
    else _flash_attn_varlen_forward
)
@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fixed-length FlashAttention backward kernel (``flash_attn_cuda.bwd``).

    The op is declared as mutating ``dq``, ``dk`` and ``dv``: the kernel fills the
    gradients into those buffers. Only ``softmax_d`` is returned.
    """
    # dq/dk/dv are allocated by the autograd functions in this module and are
    # passed through as-is; only the remaining inputs may need a stride fix-up.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    _dq, _dk, _dv, softmax_d = flash_attn_cuda.bwd(
        dout, q, k, v, out,
        softmax_lse,
        dq, dk, dv,
        alibi_slopes,
        dropout_p, softmax_scale, causal,
        window_size_left, window_size_right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return softmax_d
@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Shape-only implementation of ``flash_attn::_flash_attn_backward`` for tracing.

    Returns an uninitialized float32 ``softmax_d`` of shape
    (batch_size, num_heads, seqlen_q rounded up to a multiple of 128), with q
    unpacked as (batch_size, seqlen_q, num_heads, head_size).
    """
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    # Missing gradient buffers are allocated here but are not part of the result.
    dq = torch.empty_like(q) if dq is None else dq
    dk = torch.empty_like(k) if dk is None else dk
    dv = torch.empty_like(v) if dv is None else dv
    batch_size, seqlen_q, num_heads, _ = q.shape
    softmax_d_shape = (batch_size, num_heads, round_multiple(seqlen_q, 128))
    return torch.empty(softmax_d_shape, device=q.device, dtype=torch.float32)
# On PyTorch >= 2.4 dispatch through the registered custom op; otherwise call the
# plain Python function directly.
_wrapped_flash_attn_backward = (
    torch.ops.flash_attn._flash_attn_backward
    if torch.__version__ >= "2.4.0"
    else _flash_attn_backward
)
@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_varlen_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the variable-length FlashAttention backward kernel (``flash_attn_cuda.varlen_bwd``).

    The op is declared as mutating ``dq``, ``dk`` and ``dv``: the kernel fills the
    gradients into those buffers. Only ``softmax_d`` is returned.
    """
    # dq/dk/dv are allocated by the autograd functions in this module and are
    # passed through as-is; only the remaining inputs may need a stride fix-up.
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    _dq, _dk, _dv, softmax_d = flash_attn_cuda.varlen_bwd(
        dout, q, k, v, out,
        softmax_lse,
        dq, dk, dv,
        cu_seqlens_q, cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale,
        # NOTE(review): positional flag this wrapper always disables (believed to be
        # the kernel's zero_tensors option) - confirm against flash_attn_2_cuda.
        False,
        causal,
        window_size_left, window_size_right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return softmax_d
@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
def _flash_attn_varlen_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Shape-only implementation of ``flash_attn::_flash_attn_varlen_backward``.

    Returns an uninitialized float32 ``softmax_d`` of shape
    (num_heads, total_q + 128 * batch_size), with q unpacked as
    (total_q, num_heads, head_size) and batch_size = len(cu_seqlens_q) - 1.
    """
    dout, q, k, v, out = map(maybe_contiguous, (dout, q, k, v, out))
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape
    # Missing gradient buffers are allocated here but are not part of the result.
    dq = torch.empty_like(q) if dq is None else dq
    dk = torch.empty_like(k) if dk is None else dk
    dv = torch.empty_like(v) if dv is None else dv
    softmax_d_shape = (num_heads, total_q + 128 * batch_size)
    return torch.empty(softmax_d_shape, device=q.device, dtype=torch.float32)
# On PyTorch >= 2.4 dispatch through the registered custom op; otherwise call the
# plain Python function directly.
_wrapped_flash_attn_varlen_backward = (
    torch.ops.flash_attn._flash_attn_varlen_backward
    if torch.__version__ >= "2.4.0"
    else _flash_attn_varlen_backward
)
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for FlashAttention on a packed ``qkv`` tensor.

    q, k and v are the views ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and ``qkv[:, :, 2]``.
    The head dimension is zero-padded up to a multiple of 8 before the kernels run;
    the padding is sliced off the returned output and the returned gradient.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = (qkv[:, :, i].detach() for i in range(3))
        head_size_og = q.shape[3]
        remainder = head_size_og % 8
        if remainder:
            # The kernels need a head dimension that is a multiple of 8.
            q, k, v = [torch.nn.functional.pad(t, [0, 8 - remainder]) for t in (q, k, v)]
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # Attention probabilities are only requested when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # A single packed buffer for all three gradients; the backward op writes
        # into views of it in place.
        dqkv = torch.empty((*q.shape[:-2], 3, *q.shape[-2:]), dtype=q.dtype, device=q.device)
        remainder = dout.shape[3] % 8
        dout_padded = torch.nn.functional.pad(dout, [0, 8 - remainder]) if remainder else dout
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # Drop any head-dimension padding added in forward.
        dqkv = dqkv[..., : dout.shape[-1]]
        # One None per remaining (non-differentiable) forward input.
        return (dqkv,) + (None,) * 8
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention on a packed ``qkv`` tensor.

    q, k and v are the views ``qkv[:, 0]``, ``qkv[:, 1]`` and ``qkv[:, 2]``; the same
    ``cu_seqlens`` / ``max_seqlen`` describe both queries and keys. The head dimension
    is zero-padded up to a multiple of 8 and the padding is sliced off again on the
    way out.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = (qkv[:, i].detach() for i in range(3))
        head_size_og = q.shape[2]
        remainder = head_size_og % 8
        if remainder:
            # The kernels need a head dimension that is a multiple of 8.
            q, k, v = [torch.nn.functional.pad(t, [0, 8 - remainder]) for t in (q, k, v)]
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # Attention probabilities are only requested when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # A single packed buffer for all three gradients; the backward op writes
        # into views of it in place.
        dqkv = torch.empty((*q.shape[:-2], 3, *q.shape[-2:]), dtype=q.dtype, device=q.device)
        remainder = dout.shape[2] % 8
        dout_padded = torch.nn.functional.pad(dout, [0, 8 - remainder]) if remainder else dout
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # Drop any head-dimension padding added in forward.
        dqkv = dqkv[..., : dout.shape[-1]]
        # One None per remaining (non-differentiable) forward input.
        return (dqkv,) + (None,) * 10
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for FlashAttention with a separate ``q`` and a packed ``kv``.

    k and v are the views ``kv[:, :, 0]`` and ``kv[:, :, 1]``. The head dimension is
    zero-padded up to a multiple of 8 before the kernels run and sliced back
    afterwards. The k/v gradients are written into one packed buffer.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = (kv[:, :, i].detach() for i in range(2))
        head_size_og = q.shape[3]
        remainder = head_size_og % 8
        if remainder:
            # The kernels need a head dimension that is a multiple of 8.
            q, k, v = [torch.nn.functional.pad(t, [0, 8 - remainder]) for t in (q, k, v)]
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # Attention probabilities are only requested when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Packed buffer for the k/v gradients; the backward op writes into its views.
        dkv = torch.empty((*k.shape[:-2], 2, *k.shape[-2:]), dtype=k.dtype, device=k.device)
        remainder = dout.shape[3] % 8
        dout_padded = torch.nn.functional.pad(dout, [0, 8 - remainder]) if remainder else dout
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # Drop any head-dimension padding added in forward.
        head_size_og = dout.shape[-1]
        dq, dkv = dq[..., :head_size_og], dkv[..., :head_size_og]
        # One None per remaining (non-differentiable) forward input.
        return (dq, dkv) + (None,) * 8
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention with a packed ``kv``.

    k and v are the views ``kv[:, 0]`` and ``kv[:, 1]``; queries and keys each have
    their own cumulative sequence lengths and maximum lengths. The head dimension is
    zero-padded up to a multiple of 8 and the padding is sliced off again on the way out.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = (kv[:, i].detach() for i in range(2))
        head_size_og = q.shape[2]
        remainder = head_size_og % 8
        if remainder:
            # The kernels need a head dimension that is a multiple of 8.
            q, k, v = [torch.nn.functional.pad(t, [0, 8 - remainder]) for t in (q, k, v)]
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # Attention probabilities are only requested when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # Packed buffer for the k/v gradients; the backward op writes into its views.
        dkv = torch.empty((*k.shape[:-2], 2, *k.shape[-2:]), dtype=k.dtype, device=k.device)
        remainder = dout.shape[2] % 8
        dout_padded = torch.nn.functional.pad(dout, [0, 8 - remainder]) if remainder else dout
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # Drop any head-dimension padding added in forward.
        head_size_og = dout.shape[-1]
        dq, dkv = dq[..., :head_size_og], dkv[..., :head_size_og]
        # One None per remaining (non-differentiable) forward input.
        return (dq, dkv) + (None,) * 12
class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for FlashAttention on separate ``q``, ``k`` and ``v`` tensors.

    The head dimension is zero-padded up to a multiple of 8 before the kernels run;
    forward slices the padding off the output and backward slices it off the
    gradients, which are allocated here and filled in place by the backward op.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.shape[3]
        remainder = head_size_og % 8
        if remainder:
            # The kernels need a head dimension that is a multiple of 8.
            q, k, v = [torch.nn.functional.pad(t, [0, 8 - remainder]) for t in (q, k, v)]
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # Attention probabilities are only requested when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        remainder = dout.shape[3] % 8
        dout_padded = torch.nn.functional.pad(dout, [0, 8 - remainder]) if remainder else dout
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # Drop any head-dimension padding added in forward.
        head_size_og = dout.shape[-1]
        dq, dk, dv = (grad[..., :head_size_og] for grad in (dq, dk, dv))
        # One None per remaining (non-differentiable) forward input.
        return (dq, dk, dv) + (None,) * 8
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length FlashAttention on separate q, k, v.

    Queries and keys each have their own cumulative sequence lengths and maximum
    lengths; ``block_table`` is forwarded to the forward op (it is not used in
    backward). The head dimension is zero-padded up to a multiple of 8 and the
    padding is sliced off the output and the gradients.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.shape[2]
        remainder = head_size_og % 8
        if remainder:
            # The kernels need a head dimension that is a multiple of 8.
            q, k, v = [torch.nn.functional.pad(t, [0, 8 - remainder]) for t in (q, k, v)]
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # Attention probabilities are only requested when dropout is active.
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = (torch.empty_like(t) for t in (q, k, v))
        remainder = dout.shape[2] % 8
        dout_padded = torch.nn.functional.pad(dout, [0, 8 - remainder]) if remainder else dout
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        # Drop any head-dimension padding added in forward.
        head_size_og = dout.shape[-1]
        dq, dk, dv = (grad[..., :head_size_og] for grad in (dq, dk, dv))
        # One None per remaining (non-differentiable) forward input.
        return (dq, dk, dv) + (None,) * 13
def flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # <=0.0 means deactivate
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
    Head dimensions that are not a multiple of 8 are zero-padded internally; the padding
    is removed from the returned output and gradients.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_rounded,
            seqlen_rounded), where seqlen_rounded is seqlen rounded up to a multiple of 128.
            The probabilities are only requested from the kernel when dropout_p > 0; with
            dropout_p == 0 this is a 0-element tensor.
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
def flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run FlashAttention with a separate Q tensor and a packed KV tensor.

    Set dropout_p to 0.0 during evaluation.

    When K and V are already stacked into one tensor, prefer this over calling
    flash_attn_func on separate Q, K, V: the backward pass then avoids an explicit
    concatenation of the gradients of K and V.
    Multi-query and grouped-query attention (MQA/GQA) are supported by passing KV with
    fewer heads than Q. The number of heads in Q must be divisible by the number of heads
    in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q attend
    to head 0 of K, V, and heads 3, 4, 5 of Q attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom-right corner of the attention
    matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    A query row whose mask is entirely zero produces a zero output.

    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g. for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), enables sliding-window local attention.
        softcap: float. Any value > 0 enables softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic backward implementation, which is
            slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities. Intended for
            testing only; the returned probabilities are not guaranteed to be correct
            (they may not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of QK^T * scaling (e.g. the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The softmax output (possibly with different scaling). It also encodes the dropout
            pattern (negative means the location was dropped, nonnegative means it was kept).
    """
    # Collect the arguments in the exact positional order FlashAttnKVPackedFunc.apply expects.
    packed_args = (
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnKVPackedFunc.apply(*packed_args)
def flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # All arguments are forwarded positionally; the order must match FlashAttnFunc.apply.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """Run variable-length FlashAttention on a single packed QKV tensor.

    Set dropout_p to 0.0 during evaluation.

    When Q, K and V are already stacked into one tensor, prefer this over calling
    flash_attn_varlen_func on separate Q, K, V: the backward pass then avoids an explicit
    concatenation of the gradients of Q, K and V.
    For multi-query / grouped-query attention (MQA/GQA), use flash_attn_varlen_kvpacked_func
    or flash_attn_varlen_func instead.
    If window_size != (-1, -1), sliding-window local attention is used: the query at
    position i only attends to keys in [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scaling applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g. for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), enables sliding-window local attention.
        softcap: float. Any value > 0 enables softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic backward implementation, which is
            slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to also return the attention probabilities. Intended for
            testing only; the returned probabilities are not guaranteed to be correct
            (they may not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of QK^T * scaling (e.g. the log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The softmax output (possibly with different scaling). It also encodes the dropout
            pattern (negative means the location was dropped, nonnegative means it was kept).
    """
    # Collect the arguments in the exact positional order
    # FlashAttnVarlenQKVPackedFunc.apply expects.
    packed_args = (
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
    return FlashAttnVarlenQKVPackedFunc.apply(*packed_args)
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # All arguments are forwarded positionally; the order must match
    # FlashAttnVarlenKVPackedFunc.apply.
    return FlashAttnVarlenKVPackedFunc.apply(
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling).
        block_table [optional]: forwarded unchanged to FlashAttnVarlenFunc. Defaults to None.
           NOTE(review): presumably a paged-KV block table, analogous to flash_attn_with_kvcache's
           (batch_size, max_num_blocks_per_seq) int32 tensor, with k/v laid out as
           (num_blocks, page_block_size, nheads_k, headdim) -- confirm against FlashAttnVarlenFunc.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    # All arguments are forwarded positionally; the order must match FlashAttnVarlenFunc.apply.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
    )
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache. An int is broadcast to a (batch_size,) int32 tensor, where batch_size is
            q.shape[0] (not k_cache.shape[0], which differs with a block_table or cache_batch_idx).
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if isinstance(cache_seqlens, int):
        # cache_seqlens is documented as (batch_size,), with batch_size = q.shape[0].
        # k_cache.shape[0] is not that size when a block_table is used (it is num_blocks)
        # or when cache_batch_idx selects rows of a larger cache (it is batch_size_cache),
        # so broadcast the scalar over q's batch dimension instead.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    # Apply the same contiguity normalization to every optional index tensor,
    # including a caller-supplied cache_seqlens tensor and cache_leftpad.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    cache_leftpad = maybe_contiguous(cache_leftpad)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        # NOTE(review): positional slot between alibi_slopes and softmax_scale; presumably an
        # optional preallocated output buffer -- confirm against flash_attn_cuda.fwd_kvcache.
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        num_splits,
    )
    return (out, softmax_lse) if return_softmax_lse else out
|