import os

import torch

from .fwd_prefill import attention_prefill_forward_triton_impl
from .bwd_prefill import attention_prefill_backward_triton_impl
from .fwd_decode import attention_decode_forward_triton_impl
from .fwd_ref import attention_forward_pytorch_ref_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import MetaData, get_shape_from_layout, DEBUG

USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
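# When FLASH_ATTENTION_TRITON_AMD_REF is truthy ("1", "true" or "yes"), USE_REF
# routes every entry point below through the PyTorch reference implementation
# instead of the Triton kernels, which is useful for isolating kernel bugs.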

def fwd(q,
        k,
        v,
        o,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen_):
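    """Forward attention for fixed-length batches in "bshd" layout
    (batch, seqlen, nheads, head_dim).

    Dispatches to the PyTorch reference implementation when USE_REF is set,
    otherwise to the Triton prefill kernel. Returns
    (o, softmax_lse, exp_scores, None).
    """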
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::fwd")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("o:", o)
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("softcap:", softcap)
- print("softcap:", softcap)
- print("return_softmax:", return_softmax)
- if dropout_p != 0.0:
- raise ValueError("dropout is not supported on AMD's Triton Backend yet")
- if o is None:
- o = torch.empty_like(q)
- # Setup metadata
- metadata = MetaData(sm_scale=softmax_scale)
- metadata.max_seqlens_q = q.shape[1]
- metadata.max_seqlens_k = k.shape[1]
- metadata.layout = "bshd"
- if return_softmax:
- metadata.return_scores = True
- batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout)
-
- if causal:
- metadata.need_causal()
-
- if alibi_slopes is not None:
- metadata.need_alibi(alibi_slopes, batch, nheads_q)
-
- if dropout_p > 0.0:
- metadata.need_dropout(dropout_p, return_softmax)
-
- # Check arguments
- metadata.check_args(q, k, v, o)
    if USE_REF:
        if DEBUG:
            print("Using reference implementation")
        (output,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _) = attention_forward_pytorch_ref_impl(
            q,
            k,
            v,
            metadata.sm_scale,
            metadata.causal,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.use_exp2)
        o.copy_(output)
    else:
        if DEBUG:
            print("Using Triton implementation")
        (_,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _,
         _,
         _) = attention_prefill_forward_triton_impl(
            q,
            k,
            v,
            o,
            metadata.sm_scale,
            metadata.alibi_slopes,
            metadata.causal,
            metadata.bias,
            metadata.dropout_p,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.return_scores,
            metadata.use_exp2)

    if DEBUG:
        print("fwd outputs")
        print("o:", o, o.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None)

    return o, softmax_lse, exp_scores, None


def bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    alibi_slopes,
    dropout_p,
    softmax_scale,
    causal,
    window_size_left,
    window_size_right,
    softcap,
    deterministic,
    gen_,
    rng_state,
):
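    """Backward attention for fixed-length batches in "bshd" layout.

    Writes the gradients into the preallocated dq, dk and dv tensors and
    returns (dq, dk, dv, delta), where delta is the auxiliary term produced
    by the backward pass.
    """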
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::bwd")
        print("dout:", dout, dout.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("out:", out, out.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
- print("out:", out)
- print("softmax_scale:", softmax_scale)
- print("causal:", causal)
- print("window_size_left:", window_size_left)
- print("window_size_right:", window_size_right)
- print("deterministic:", deterministic)
- print("gen_:", gen_)
- print("rng_state:", rng_state)
- if dropout_p != 0.0:
- raise ValueError("dropout is not supported on AMD yet")
    if USE_REF:
        if DEBUG:
            print("Using reference implementation")
        dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale,
            causal,
            "bshd",
            None,   # cu_seqlens_q
            None,   # cu_seqlens_k
            None,   # max_seqlen_q
            None,   # max_seqlen_k
            False,  # use_exp2
        )
        dq.copy_(dq_ref)
        dk.copy_(dk_ref)
        dv.copy_(dv_ref)
        delta = delta_ref
    else:
        if DEBUG:
            print("Using Triton implementation")
        dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            softmax_scale,
            alibi_slopes,
            causal,
            "bshd",
            None,   # cu_seqlens_q
            None,   # cu_seqlens_k
            None,   # max_seqlen_q
            None,   # max_seqlen_k
            False,  # use_exp2
        )
        delta = delta_triton

    if DEBUG:
        print("bwd outputs")
        print("dv:", dv, dv.shape)
        print("dk:", dk, dk.shape)
        print("dq:", dq, dq.shape)

    return dq, dk, dv, delta


def varlen_fwd(
    q,
    k,
    v,
    o,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    block_table_,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    zero_tensors,
    causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen_):
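    """Forward attention for variable-length batches in packed "thd" layout.

    Sequences are concatenated along the token dimension and delimited by
    the cu_seqlens_q / cu_seqlens_k offset tensors. Returns
    (o, softmax_lse, exp_scores, None).
    """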
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::varlen_fwd")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape)
        print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape)
        print("alibi_slopes:", alibi_slopes)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("gen_:", gen_)

    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD's Triton backend yet")
    if o is None:
        o = torch.empty_like(q)

    # Set up metadata
    metadata = MetaData(sm_scale=softmax_scale)
    if return_softmax:
        metadata.return_scores = True
    metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)  # sets layout to "thd" and other metadata fields

    # get shapes
    batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(
        q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)

    if causal:
        metadata.need_causal()

    if alibi_slopes is not None:
        metadata.need_alibi(alibi_slopes, batch, nheads_q)

    if dropout_p > 0.0:  # currently unreachable: dropout raises above
        metadata.need_dropout(dropout_p, return_softmax)

    # Check arguments
    metadata.check_args(q, k, v, o)
    if USE_REF:
        if DEBUG:
            print("Using reference implementation")
        (output,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _) = attention_forward_pytorch_ref_impl(
            q,
            k,
            v,
            metadata.sm_scale,
            metadata.causal,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.use_exp2)
        o.copy_(output)
    else:
        if DEBUG:
            print("Using Triton implementation")
        (_,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _,
         _,
         _) = attention_prefill_forward_triton_impl(
            q,
            k,
            v,
            o,
            metadata.sm_scale,
            metadata.alibi_slopes,
            metadata.causal,
            metadata.bias,
            metadata.dropout_p,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.return_scores,
            metadata.use_exp2)

    if DEBUG:
        print("varlen_fwd outputs")
        print("o:", o, o.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None)

    return o, softmax_lse, exp_scores, None


def varlen_bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    zero_tensors,
    causal,
    window_size_left,
    window_size_right,
    softcap,
    deterministic,
    gen_,
    rng_state,
):
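    """Backward attention for variable-length batches in packed "thd" layout.

    Writes the gradients into the preallocated dq, dk and dv tensors and
    returns (dq, dk, dv, delta).
    """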
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::varlen_bwd")
        print("dout:", dout, dout.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("out:", out, out.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape)
        print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape)
        print("alibi_slopes:", alibi_slopes)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("deterministic:", deterministic)
        print("gen_:", gen_)
        print("rng_state:", rng_state)

    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD's Triton backend yet")
    if USE_REF:
        if DEBUG:
            print("Using reference implementation")
        dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale,
            causal,
            "thd",
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            False,  # use_exp2
        )
        dq.copy_(dq_ref)
        dk.copy_(dk_ref)
        dv.copy_(dv_ref)
        delta = delta_ref
    else:
        if DEBUG:
            print("Using Triton implementation")
        dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            softmax_scale,
            alibi_slopes,
            causal,
            "thd",
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            False,  # use_exp2
        )
        delta = delta_triton

    if DEBUG:
        print("varlen_bwd outputs")
        print("delta:", delta, delta.shape)
        print("dv:", dv, dv.shape)
        print("dk:", dk, dk.shape)
        print("dq:", dq, dq.shape)

    return dq, dk, dv, delta


def fwd_kvcache(
    q,
    k_cache,
    v_cache,
    k,
    v,
    cache_seqlens,
    rotary_cos,
    rotary_sin,
    cache_batch_idx,
    cache_leftpad,
    block_table,
    alibi_slopes,
    out,
    softmax_scale,
    causal,
    window_size_left,
    window_size_right,
    softcap,
    rotary_interleaved,
    num_splits):
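    """Forward attention against a KV cache (decode path, "bshd" layout).

    When k and v are provided they are treated as new entries to append to
    the cache (metadata.new_kv). rotary_cos/rotary_sin, cache_leftpad,
    block_table and num_splits are accepted for API compatibility but are
    not used by this wrapper. Returns (output, softmax_lse).
    """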
    if out is None:
        out = torch.empty_like(q)

    # fill metadata
    metadata = MetaData(sm_scale=softmax_scale)
    metadata.layout = "bshd"
    metadata.max_seqlens_q = q.shape[1]
    metadata.max_seqlens_k = k_cache.shape[1]
    metadata.cache_seqlens = cache_seqlens
    metadata.cache_batch_idx = cache_batch_idx

    if k is not None and v is not None:
        metadata.new_kv = True
        metadata.seqlen_new = k.shape[1]
        metadata.k_new = k
        metadata.v_new = v

    if causal:
        metadata.need_causal()

    if alibi_slopes is not None:
        batch, _, nheads_q, _ = q.shape
        metadata.need_alibi(alibi_slopes, batch, nheads_q)

    # launch kernel
    # TODO: pass `out` into the kernel; the extra output copy may be slowing this path down
    output, softmax_lse = attention_decode_forward_triton_impl(
        q,
        k_cache,
        v_cache,
        metadata.sm_scale,
        metadata.causal,
        metadata.alibi_slopes,
        metadata.layout,
        metadata.cache_seqlens,
        metadata.cache_batch_idx,
        metadata.new_kv,
        metadata.k_new,
        metadata.v_new,
    )
    return output, softmax_lse
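

# A minimal smoke test (a sketch, not part of the upstream API): the shapes
# are illustrative, and the relative imports above mean this only runs as a
# module (python -m ...) on a ROCm device that PyTorch exposes as "cuda".
if __name__ == "__main__":
    batch, seqlen, nheads, head_dim = 2, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
    o, softmax_lse, exp_scores, _ = fwd(
        q, k, v,
        None,              # o: allocated internally when None
        None,              # alibi_slopes
        0.0,               # dropout_p: must be 0.0 (unsupported here)
        head_dim ** -0.5,  # softmax_scale
        True,              # causal
        -1, -1,            # window_size_left / window_size_right
        0.0,               # softcap
        False,             # return_softmax
        None,              # gen_
    )
    print("o:", o.shape, "softmax_lse:", softmax_lse.shape)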