12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820 |
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Forward pass with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Different sequence lengths for q and k
Limitations:
1) The nested-tensor (variable-length) API currently does not support dropout
   or bias.
2) Head dims that are not a power of two are padded up to the next of
   32/64/128/256 (head dims above 256 are rejected); some autotune block sizes
   are known to produce mismatches for such head dims.
"""
- import torch
- import triton
- import triton.language as tl
# Module-level dtype constant, annotated as tl.constexpr (presumably so that
# @triton.jit code could read it). Not referenced by any function visible in
# this chunk — NOTE(review): confirm whether external code still uses it.
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
    """Round-up integer division of ``x`` by ``y``.

    Equals ``ceil(x / y)`` for ``x >= 0`` and ``y > 0``. For ``x <= 0`` (with
    ``y > 0``) the result is ``<= 0``, which callers can use to detect an empty
    block range.
    """
    rounded_up = x + (y - 1)
    return rounded_up // y
@triton.jit
def max_fn(x, y):
    """Return the element-wise maximum of ``x`` and ``y``.

    Uses the public ``tl.maximum`` builtin (the same one the attention loop in
    this file uses) rather than ``tl.math.max``, which newer Triton releases no
    longer provide.
    """
    return tl.maximum(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Build an ``(m, n)`` tile of Philox counter offsets.

    Element ``(i, j)`` equals ``philox_offset + i * stride + j``.
    ``philox_seed`` and ``dropout_p`` are accepted but not used here.
    """
    row_offsets = tl.arange(0, m)[:, None] * stride
    col_offsets = tl.arange(0, n)[None, :]
    return philox_offset + row_offsets + col_offsets
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return ``tl.rand`` samples for the ``(m, n)`` offset tile.

    The offsets come from ``dropout_offsets`` and are cast to uint32 before
    being fed to ``tl.rand`` together with ``philox_seed``.
    """
    offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                              stride)
    counters = offsets.to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, counters)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    """Return an ``(m, n)`` keep-mask.

    An element is True where its random draw from ``dropout_rng`` is strictly
    greater than ``dropout_p``.
    """
    return dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
                       stride) > dropout_p
@triton.jit
def load_fn(block_ptr, first, second, pad):
    """Load a 2-D block pointer, boundary-checking only the requested dims.

    Args:
        block_ptr: block pointer to load from.
        first: if true, boundary-check dim 0.
        second: if true, boundary-check dim 1.
        pad: padding option for out-of-bounds elements (e.g. "zero").
    """
    if first:
        if second:
            tensor = tl.load(block_ptr, boundary_check=(0, 1),
                             padding_option=pad)
        else:
            tensor = tl.load(block_ptr, boundary_check=(0, ),
                             padding_option=pad)
    else:
        if second:
            tensor = tl.load(block_ptr, boundary_check=(1, ),
                             padding_option=pad)
        else:
            # No boundary check requested: plain unmasked load.
            tensor = tl.load(block_ptr)
    return tensor
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    start_m,
    actual_seqlen_k,
    dropout_p,
    philox_seed,
    batch_philox_offset,
    encoded_softmax_block_ptr,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    bias_ptr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
):
    """Fold one range of K/V blocks into the online-softmax state.

    Iterates ``start_n`` over ``range(block_min, block_max, BLOCK_N)``. For each
    key block it forms ``qk = q @ k``, optionally applies key-padding and causal
    masks (``MASK_STEPS``), and adds a bias tile scaled by log2(e). It then
    updates the running row max ``m_i``, row sum ``l_i`` and the unnormalised
    output ``acc`` using base-2 exponentials. The caller in this file
    pre-scales ``q`` by ``sm_scale * log2(e)``.

    The block pointers are advanced only locally. Only ``(acc, l_i, m_i)`` are
    returned, so callers must advance their own pointer copies. ``masked_blocks``
    is accepted but not used in this function.

    Returns:
        The updated ``(acc, l_i, m_i)``.
    """
    # loop over k, v, and update accumulator
    for start_n in range(block_min, block_max, BLOCK_N):
        # Blocks that may run past seqlen_k need a boundary check on the
        # sequence dim; the head dim is checked whenever it is padded.
        k = load_fn(
            K_block_ptr,
            PADDED_HEAD,
            MASK_STEPS and (n_extra_tokens != 0),
            "zero",
        )
        if PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # The loop walks K forward, so only the last iteration can contain key
        # positions past seqlen_k.
        # TODO: This can be optimized to only be true for the padded block.
        if MASK_STEPS:  # noqa: SIM102
            # On the last block, mask key positions >= seqlen_k when seqlen_k
            # is not a multiple of BLOCK_N.
            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
                boundary_m = tl.full([BLOCK_M],
                                     actual_seqlen_k,
                                     dtype=tl.int32)
                size_n = start_n + OFFS_N[None, :]
                mask = size_n < boundary_m[:, None]
                qk = tl.where(mask, qk, float("-inf"))
            if IS_CAUSAL:
                causal_boundary = start_n + offs_n_causal
                causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
                qk = tl.where(causal_mask, qk, float("-inf"))
        # -- compute qk ----
        # Masked entries are already -inf, and -inf + finite stays -inf.
        qk += tl.dot(q, k)
        if bias_ptr is not None:
            # Always boundary-check the rows: the last query tile can extend
            # past seqlen_q, i.e. outside the bias block's declared shape.
            bias = load_fn(bias_ptr, True, MASK_STEPS
                           and (n_extra_tokens != 0), "zero")
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += bias * 1.44269504089
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            # NOTE(review): the trailing "- BLOCK_N" is unexplained; keep it in
            # sync with whatever backward kernel regenerates this mask.
            philox_offset = (batch_philox_offset +
                             start_m * BLOCK_M * actual_seqlen_k + start_n -
                             BLOCK_N)
            keep = dropout_mask(
                philox_seed,
                philox_offset,
                dropout_p,
                BLOCK_M,
                BLOCK_N,
                actual_seqlen_k,
            )
            if RETURN_ENCODED_SOFTMAX:
                # Dropped entries are stored negated so the mask is recoverable.
                tl.store(
                    encoded_softmax_block_ptr,
                    tl.where(keep, p,
                             -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
            tl.store(
                encoded_softmax_block_ptr,
                p.to(encoded_softmax_block_ptr.type.element_ty),
            )
        # -- rescale the accumulator to the new running max --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            v = load_fn(
                V_block_ptr,
                MASK_STEPS and (n_extra_tokens != 0),
                PADDED_HEAD,
                "zero",
            )
        # -- update m_i and l_i --
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
                                                   (0, BLOCK_N))
    return acc, l_i, m_i
@triton.autotune(
    configs=[
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 64,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 256,
                "BLOCK_N": 128,
                "waves_per_eu": 2,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": True,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 128,
                "BLOCK_N": 64,
                "waves_per_eu": 3,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
        triton.Config(
            {
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        triton.Config(
            {
                "BLOCK_M": 32,
                "BLOCK_N": 32,
                "waves_per_eu": 4,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=8,
        ),
        # TODO: This config fails with head_size not pow2 with data mismatches.
        # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
        #                'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
        triton.Config(
            {
                "BLOCK_M": 16,
                "BLOCK_N": 16,
                "waves_per_eu": 1,
                "PRE_LOAD_V": False,
            },
            num_stages=1,
            num_warps=4,
        ),
    ],
    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    sm_scale,
    L,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    stride_bz,
    stride_bh,
    stride_bm,
    stride_bn,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
    """Flash-attention forward kernel producing one BLOCK_M x head tile of Out.

    Program ids:
        axis 0: BLOCK_M tile of query rows; axis 1: query head; axis 2: batch
        entry (sequence).

    With ``VARLEN`` the per-sequence start offsets and lengths are read from
    ``cu_seqlens_q`` / ``cu_seqlens_k``. Otherwise ``MAX_SEQLENS_Q`` /
    ``MAX_SEQLENS_K`` are used as the lengths. K/V heads are shared across
    ``HQ // HK`` query heads (MQA/GQA). The head dim is processed at
    ``BLOCK_DMODEL`` and boundary-checked against ``ACTUAL_BLOCK_DMODEL``.

    Outputs:
        ``Out`` receives the normalised attention output. Fully causally masked
        rows are written as 0. ``L`` (log-sum-exp) is currently not written;
        the corresponding stores are commented out.
    """
    start_m = tl.program_id(0)
    off_h_q = tl.program_id(1)
    off_z = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). A tile whose first row is
        # at or past this sequence's seqlen_q has no valid rows: return early.
        if start_m * BLOCK_M >= seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K
    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # inf written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.
        # This captures the decrease in n_blocks if we have a rectangular attn
        # matrix
        n_blocks_seqlen = cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)
        # If we have no blocks after adjusting for seqlen deltas, this WG is
        # part of the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                        off_h_q * stride_oh)
            # Bound the head dim by ACTUAL_BLOCK_DMODEL so a padded head never
            # writes past the real head size.
            O_block_ptr = tl.make_block_ptr(
                base=Out + o_offset,
                shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
                strides=(stride_om, stride_on),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_DMODEL),
                order=(1, 0),
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
            # We still need to write 0s to the result. This matches the
            # NaN -> 0 fix-up that the epilogue applies to partially masked
            # tiles.
            tl.store(O_block_ptr, acc, boundary_check=(0, 1))
            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
            # + offs_m
            # We store inf to LSE, not -inf because in the bwd pass,
            # we subtract this
            # from qk which makes it -inf, such that exp(qk - inf) = 0
            # for these masked blocks.
            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
            # tl.store(l_ptrs, l)
            # TODO: Should dropout and return encoded softmax be handled here?
            return
    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE: tl.constexpr = HQ // HK
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
    # Only compared against 0 below: non-zero means the last K block is
    # partially out of range.
    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
    # Compute pointers for all the tensors used in this kernel.
    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
                cu_seqlens_q_start * stride_qm)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
                cu_seqlens_k_start * stride_kn)
    # K is addressed transposed, (head_dim, seqlen_k), so q @ k gives qk.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
                cu_seqlens_k_start * stride_vk)
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    if BIAS_TYPE != 0:
        # NOTE(review): the base omits off_z * stride_bz and the cu_seqlens
        # offsets, so every batch entry reads the same per-head bias slice.
        # Confirm this is the intended bias layout.
        bias_ptr = tl.make_block_ptr(
            base=bias + off_h_q * stride_bh,
            shape=(seqlen_q, seqlen_k),
            strides=(stride_bm, stride_bn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
        batch_philox_offset = philox_offset_base \
                              + (off_z * HQ + off_h_q) \
                              * seqlen_q * seqlen_k
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
    # In this case, we return an invalid pointer to indicate the mask is not
    # valid.
    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
    if RETURN_ENCODED_SOFTMAX:
        encoded_softmax_block_ptr = tl.make_block_ptr(
            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
            shape=(seqlen_q, seqlen_k),
            strides=(seqlen_k, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
    else:
        encoded_softmax_block_ptr = 0
    # initialize pointer to m and l
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
    # have native e^x support in HW.
    qk_scale = sm_scale * 1.44269504089
    # Q is loaded once at the beginning and shared by all N blocks.
    q = load_fn(Q_block_ptr, True, padded_head, "zero")
    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
    # block. In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
            block_min,
            block_max,
            0,
            0,
            0,
            bias_ptr,
            # IS_CAUSAL, ....
            False,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            False,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N
    tl.debug_barrier()
    # Remaining blocks, if any, need masking (causal and/or key padding).
    if masked_blocks > 0:
        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
        # _attn_fwd_inner only advances its local copies of the block
        # pointers, so skip past the full blocks here. Every pointer moves by
        # the same number of key columns.
        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
            encoded_softmax_block_ptr = tl.advance(
                encoded_softmax_block_ptr, (0, n_full_blocks * BLOCK_N))
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            start_m,
            seqlen_k,
            dropout_p,
            philox_seed,
            batch_philox_offset,
            encoded_softmax_block_ptr,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            bias_ptr,
            IS_CAUSAL,
            BLOCK_M,
            BLOCK_DMODEL,
            BLOCK_N,
            offs_m,
            offs_n,
            # _, MASK_STEPS, ...
            PRE_LOAD_V,
            True,
            ENABLE_DROPOUT,
            RETURN_ENCODED_SOFTMAX,
            padded_head,
        )
    # epilogue
    acc = acc / l_i[:, None]
    if ENABLE_DROPOUT:
        acc = acc / (1 - dropout_p)
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
                                        causal_start_idx,
                                        dtype=tl.int32)
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = (mask_m_offsets[:, None] >=
                             out_mask_boundary[None, :])
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
    # few rows. This is only true for the last M block. For others,
    # overflow_size will be -ve
    # overflow_size = end_m_idx - seqlen_q
    # if overflow_size > 0:
    #     boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
    #     # This is a > check because mask being 0 blocks the store.
    #     l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
    # else:
    #     tl.store(l_ptrs, m_i + tl.math.log2(l_i))
    # write back O
    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
                off_h_q * stride_oh)
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # Need boundary check on this to make sure the padding from the
    # Q and KV tensors in both dims are not part of what we store back.
    # TODO: Do the boundary check optionally.
    tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
    q,
    k,
    v,
    o,
    varlen=True,
    max_seqlens=None,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
):
    """Assert that q/k/v/o and their metadata form a valid attention call.

    Two layouts are accepted:

    * ``varlen`` (default): 3-D packed tensors ``(total_tokens, heads,
      head_size)``, with ``cu_seqlens_q`` and ``cu_seqlens_k`` both given and
      of equal length.
    * padded (``varlen`` falsy): 4-D tensors ``(batch, heads, seqlen,
      head_size)``, with a positive ``max_seqlens``.

    Raises:
        AssertionError: if any requirement is violated. The checks cover
            mismatched ranks, k/v shape mismatch, differing head sizes or
            dtypes, head_size > 256, ``o`` not shaped like ``q``, and query
            heads not a multiple of kv heads.
    """
    assert q.dim() == k.dim() == v.dim()
    if not varlen:
        assert q.dim() == 4
        assert max_seqlens > 0
    else:
        assert q.dim() == 3
        assert cu_seqlens_q is not None
        assert cu_seqlens_k is not None
        assert len(cu_seqlens_q) == len(cu_seqlens_k)

    # In both layouts dim 1 is the head axis and the last dim the head size.
    nheads_q, nheads_k = q.shape[1], k.shape[1]
    head_size = q.shape[-1]

    assert k.shape == v.shape
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    # TODO: Change assert if we support qkl f8 and v f16
    assert q.dtype == k.dtype == v.dtype
    assert head_size <= 256
    assert o.shape == q.shape
    # GQA/MQA: every kv head must serve a whole group of query heads.
    assert nheads_q % nheads_k == 0
class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the variable-length ``attn_fwd`` kernel.

    Only ``forward`` is defined in this class.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlens_q,
        max_seqlens_k,
        causal=False,
        sm_scale=1.0,
        bias=None,
    ):
        """Run the variable-length flash-attention forward pass.

        Args:
            ctx: autograd context. The launch grid, scale, causal flag and
                dropout/RNG settings are stored on it.
            q, k, v: packed tensors of shape ``(total_tokens, heads,
                head_size)``; ``k`` and ``v`` may have fewer heads (GQA/MQA).
            o: output tensor, or None to allocate one like ``q`` with ``v``'s
                dtype.
            cu_seqlens_q, cu_seqlens_k: cumulative sequence offsets. Their
                length is ``batch + 1``.
            max_seqlens_q, max_seqlens_k: longest query / key sequence. The
                query value sizes the launch grid.
            causal: apply causal masking.
            sm_scale: scale applied to ``q @ k^T`` before the softmax.
            bias: optional additive attention bias. Its first four strides
                are passed to the kernel.

        Returns:
            ``(o, encoded_softmax)``. Dropout is disabled here, so
            ``encoded_softmax`` is always None.
        """
        if o is None:
            o = torch.empty_like(q, dtype=v.dtype)
        check_args(
            q,
            k,
            v,
            o,
            varlen=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )
        # Packed varlen layout: (total_tokens, heads, head_size). The strides
        # are reordered into the kernel's (batch, head, seq, dim) argument
        # order with a batch stride of 0; sequences are located through
        # cu_seqlens instead.
        _, nheads_q, head_size = q.shape
        _, nheads_k, _ = k.shape
        batch = len(cu_seqlens_q) - 1
        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))

        # Pad the head size up to the smallest supported kernel head dim.
        # Search an ordered tuple rather than a set, because set iteration
        # order is an implementation detail and "first larger element" would
        # depend on it.
        padded_d_model = next(
            (dim for dim in (32, 64, 128, 256) if dim >= head_size), None)
        assert padded_d_model is not None

        def grid(META):
            # One program per (query tile, query head, sequence).
            return (
                triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
                nheads_q,
                batch,
            )

        encoded_softmax = None
        # Seed the RNG so we get reproducible results for testing.
        philox_seed = 0x1BF52
        philox_offset = 0x1D4B42
        if bias is not None:
            bias_strides = tuple(bias.stride(dim) for dim in range(4))
        else:
            bias_strides = (0, 0, 0, 0)

        attn_fwd[grid](
            q,
            k,
            v,
            bias,
            sm_scale,
            None,
            o,
            *q_strides,
            *k_strides,
            *v_strides,
            *o_strides,
            *bias_strides,
            cu_seqlens_q,
            cu_seqlens_k,
            dropout_p=0.0,
            philox_seed=philox_seed,
            philox_offset_base=philox_offset,
            encoded_softmax=encoded_softmax,
            HQ=nheads_q,
            HK=nheads_k,
            ACTUAL_BLOCK_DMODEL=head_size,
            MAX_SEQLENS_Q=max_seqlens_q,
            MAX_SEQLENS_K=max_seqlens_k,
            IS_CAUSAL=causal,
            VARLEN=True,
            BLOCK_DMODEL=padded_d_model,
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
        )

        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = head_size
        ctx.causal = causal
        ctx.dropout_p = 0.0
        ctx.philox_seed = philox_seed
        ctx.philox_offset = philox_offset
        ctx.encoded_softmax = encoded_softmax
        ctx.return_encoded_softmax = False
        return o, encoded_softmax
# Public entry point: calling triton_attention(...) runs _attention.forward
# through torch.autograd.Function.apply.
triton_attention = _attention.apply
|