123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- import torch
- import os
- import triton
# Boolean feature toggles, read from the environment once at import time.
# A toggle is on when its variable (case-insensitively) is one of the
# spellings below; an unset variable falls back to '0', i.e. off.
_ENABLED_SPELLINGS = ('1', 'true', 'yes')
AUTOTUNE = os.getenv('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in _ENABLED_SPELLINGS
DEBUG = os.getenv('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in _ENABLED_SPELLINGS
PERF = os.getenv('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in _ENABLED_SPELLINGS
class MetaData():
    """Mutable collection of options describing one attention invocation.

    Every option has a class-level default; the ``need_*`` / ``set_*``
    methods override it on the instance. ``check_args`` validates the
    collected options against the actual q/k/v/o tensors.
    """

    # Cumulative sequence offsets; populated only by set_varlen_params().
    cu_seqlens_q = None
    cu_seqlens_k = None
    # Longest query / key sequence length seen so far.
    max_seqlens_q = 0
    max_seqlens_k = 0
    bias = None
    alibi_slopes = None
    causal = False
    # Number of sequences described by cu_seqlens_* (len - 1).
    num_contexts = 0
    varlen = False
    # One of the layout strings understood by get_shape_from_layout().
    layout = None
    cache_seqlens = None
    cache_batch_idx = None
    new_kv = False
    seqlen_new = None
    k_new = None
    v_new = None
    dropout_p = 0.0
    return_scores = False
    # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
    use_exp2 = False

    def __repr__(self) -> str:
        """Return a multi-line listing of every option, one per line."""
        shown = (
            "sm_scale", "cu_seqlens_q", "cu_seqlens_k", "max_seqlens_q",
            "max_seqlens_k", "bias", "alibi_slopes", "causal", "num_contexts",
            "varlen", "layout", "cache_seqlens", "cache_batch_idx", "new_kv",
            "seqlen_new", "k_new", "v_new", "dropout_p", "return_scores",
            "use_exp2",
        )
        body = ",\n".join(f" {name}={getattr(self, name)}" for name in shown)
        return f"MetaData(\n{body}\n)"

    def __init__(self, sm_scale=1.0):
        """Create metadata with the given softmax scale; all else defaulted."""
        self.sm_scale = sm_scale

    def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
        """Switch to variable-length ('thd') mode using cumulative offsets.

        Args:
            cu_seqlens_q: 1-D array-like of cumulative query offsets with at
                least two entries (start and end of one sequence).
            cu_seqlens_k: Same for keys; must have the same length.

        Raises:
            AssertionError: if fewer than two offsets are given or the two
                arrays differ in length. Nothing is modified in that case.
        """
        # Validate first so a bad call does not leave the object half-updated.
        # Without "varlen", there should still be one sequence.
        assert len(cu_seqlens_q) >= 2
        assert len(cu_seqlens_q) == len(cu_seqlens_k)

        self.varlen = True
        self.layout = 'thd'
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_k = cu_seqlens_k
        self.num_contexts = len(cu_seqlens_q) - 1

        # Per-sequence lengths are adjacent differences; take their maximum in
        # one vectorised step (a single .item() per tensor) instead of reading
        # every element back individually.
        longest_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
        longest_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
        # Never shrink a maximum that was already recorded on this instance.
        self.max_seqlens_q = max(longest_q, self.max_seqlens_q)
        self.max_seqlens_k = max(longest_k, self.max_seqlens_k)

    def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):
        """Attach an additive bias tensor.

        The bias must be a 4-D CUDA tensor whose first dimension is 1 and
        whose last two dimensions are ``(seqlen_q, seqlen_k)``. ``batch`` and
        ``nheads`` are accepted but not checked here.
        """
        assert bias.is_cuda
        assert bias.dim() == 4
        assert bias.shape[0] == 1
        assert bias.shape[2:] == (seqlen_q, seqlen_k)
        self.bias = bias

    def need_alibi(self, alibi_slopes, batch, nheads):
        """Attach ALiBi slopes; must be a CUDA tensor of shape (batch, nheads)."""
        assert alibi_slopes.is_cuda
        assert alibi_slopes.dim() == 2
        assert alibi_slopes.shape[0] == batch
        assert alibi_slopes.shape[1] == nheads
        self.alibi_slopes = alibi_slopes

    def need_causal(self):
        """Enable the causal flag."""
        self.causal = True

    def need_dropout(self, dropout_p, return_scores):
        """Record the dropout probability and whether scores are returned."""
        self.dropout_p = dropout_p
        self.return_scores = return_scores

    def check_args(self, q, k, v, o):
        """Validate q/k/v/o against the options stored on this instance.

        Raises:
            AssertionError: on the first violated constraint.
        """
        # The layout drives all shape interpretation below, so verify it
        # before it is used rather than afterwards.
        assert self.layout is not None
        assert self.layout == 'thd' or not self.varlen
        assert q.dim() == k.dim() and q.dim() == v.dim()

        _, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(
            q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k,
            self.max_seqlens_q, self.max_seqlens_k)

        if self.varlen:
            assert q.dim() == 3
            assert self.cu_seqlens_q is not None
            assert self.cu_seqlens_k is not None
            assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
            # TODO: Remove once bias is supported with varlen
            assert self.bias is None
            # TODO: Remove once dropout is supported with varlen
            assert self.dropout_p == 0.0
            # assert not self.return_scores
        else:
            assert q.dim() == 4
            assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
            assert self.cu_seqlens_q is None and self.cu_seqlens_k is None

        assert k.shape == v.shape
        assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
        # TODO: Change assert if we support qkl f8 and v f16
        assert q.dtype == k.dtype and q.dtype == v.dtype
        assert head_size <= 256
        assert o.shape == q.shape
        # Query heads must be an integer multiple of key heads.
        assert (nheads_q % nheads_k) == 0
def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False):
    """Build q, k, v tensors and matching MetaData for a fixed-length layout.

    The global torch seed is reset to 20 on every call. ``layout`` must be
    'bhsd' or 'bshd'. With ``DEBUG_INPUT`` the tensors hold the position index
    along the sequence axis (broadcast over the other axes) and the softmax
    scale is 1; otherwise they are standard-normal random and the scale is
    ``D_HEAD ** -0.5``.

    Returns:
        ``(q, k, v, input_metadata)``; q/k/v require grad, v shares k's shape.
    """
    torch.manual_seed(20)

    # Work out tensor shapes and which axis carries the sequence position.
    if layout == 'bhsd':
        seq_axis = 2
        q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
        k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
    elif layout == 'bshd':
        seq_axis = 1
        q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
        k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
    else:
        assert False, f'Got unsupported tensor layout: {layout}'

    if DEBUG_INPUT:
        def position_ramp(seqlen, full_shape):
            # 0..seqlen-1 laid along the sequence axis, then broadcast out.
            ramp_shape = [1, 1, 1, 1]
            ramp_shape[seq_axis] = seqlen
            ramp = torch.arange(seqlen, dtype=dtype, device=device).view(*ramp_shape)
            return ramp.expand(*full_shape).contiguous().requires_grad_()

        q = position_ramp(N_CTX_Q, q_tensor_shape)
        k = position_ramp(N_CTX_K, k_tensor_shape)
        v = position_ramp(N_CTX_K, k_tensor_shape)
        sm_scale = 1
    else:
        q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True)
        k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
        v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
        sm_scale = D_HEAD**-0.5

    input_metadata = MetaData(sm_scale=sm_scale)
    input_metadata.max_seqlens_q = N_CTX_Q
    input_metadata.max_seqlens_k = N_CTX_K
    input_metadata.layout = layout
    return q, k, v, input_metadata
def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
    """Build packed (total_tokens, heads, D_HEAD) q, k, v plus varlen MetaData.

    The global torch seed is reset to 20 on every call. Each of the ``Z``
    sequences gets either a fixed length of ``N_CTX // Z`` (``equal_seqlens``)
    or a random length in ``[1, N_CTX // Z]``. With ``DEBUG_INPUT`` the tensors
    hold the token index broadcast over heads and head dim and the softmax
    scale is 1; otherwise they are standard-normal random with scale
    ``D_HEAD ** -0.5``.

    Returns:
        ``(q, k, v, input_metadata)`` with q/k/v requiring grad.
    """
    torch.manual_seed(20)

    per_seq_cap_q = N_CTX_Q // Z
    per_seq_cap_k = N_CTX_K // Z
    if equal_seqlens:
        seqlens_q = torch.full((Z,), per_seq_cap_q, dtype=torch.int32)
        seqlens_k = torch.full((Z,), per_seq_cap_k, dtype=torch.int32)
    else:
        # randint's upper bound is exclusive, hence the +1.
        seqlens_q = torch.randint(1, per_seq_cap_q + 1, (Z,), dtype=torch.int32)
        seqlens_k = torch.randint(1, per_seq_cap_k + 1, (Z,), dtype=torch.int32)

    def cumulative_offsets(seqlens):
        # Prepend 0 to the running sum, then move to the device as int32.
        leading_zero = torch.tensor([0], dtype=torch.int32)
        offsets = torch.cat([leading_zero, seqlens.cumsum(dim=0)])
        return offsets.to(device=device).to(torch.int32)

    cu_seqlens_q = cumulative_offsets(seqlens_q)
    cu_seqlens_k = cumulative_offsets(seqlens_k)

    # The final offset is the total token count.
    total_q = cu_seqlens_q[-1].item()
    total_k = cu_seqlens_k[-1].item()

    if DEBUG_INPUT:
        def token_ramp(total, heads):
            # Deterministic values: token index repeated over heads and dim.
            ramp = torch.arange(total, dtype=dtype, device=device).view(total, 1, 1)
            return ramp.expand(total, heads, D_HEAD).contiguous().requires_grad_()

        q = token_ramp(total_q, HQ)
        k = token_ramp(total_k, HK)
        v = token_ramp(total_k, HK)
        sm_scale = 1
    else:
        q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
        k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
        v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
        sm_scale = D_HEAD ** -0.5

    input_metadata = MetaData(sm_scale=sm_scale)
    input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
    return q, k, v, input_metadata
def get_shape_from_layout(q, k, layout, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=None, max_seqlen_k=None):
    """Extract logical attention dimensions from q and k for a given layout.

    For 'bhsd' and 'bshd' everything is read from the 4-D tensor shapes and
    the ``max_seqlen_*`` arguments are ignored. For 'thd' the batch size is
    ``len(cu_seqlens) - 1``, heads and head size come from dims 1 and 2, and
    ``max_seqlen_*`` are passed through unchanged.

    Returns:
        ``(batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k)``.

    Raises:
        AssertionError: for an unknown layout, or when q and k disagree on
            batch size or head size.
    """
    if layout == 'thd':
        q_batch = len(cu_seqlens_q) - 1
        q_heads, q_head_dim = q.shape[1], q.shape[2]
        k_batch = len(cu_seqlens_k) - 1
        k_heads, k_head_dim = k.shape[1], k.shape[2]
    elif layout == 'bshd':
        q_batch, max_seqlen_q, q_heads, q_head_dim = q.shape
        k_batch, max_seqlen_k, k_heads, k_head_dim = k.shape
    elif layout == 'bhsd':
        q_batch, q_heads, max_seqlen_q, q_head_dim = q.shape
        k_batch, k_heads, max_seqlen_k, k_head_dim = k.shape
    else:
        assert False, "Got unsupported layout."

    # q and k must describe the same batch and the same head size.
    assert q_batch == k_batch
    assert q_head_dim == k_head_dim
    return q_batch, q_heads, k_heads, q_head_dim, max_seqlen_q, max_seqlen_k
def get_strides_from_layout(q, k, v, o, layout):
    """Return 4-tuples of strides for q, k, v and o in a common axis order.

    'bhsd' strides are returned as stored; 'bshd' has dims 1 and 2 swapped so
    the result follows the same order as 'bhsd'; 'thd' tensors are 3-D, so the
    leading entry is 0 followed by the strides of dims 1, 0 and 2.

    Raises:
        AssertionError: for an unknown layout.
    """
    if layout == 'bhsd':
        def ordered_strides(t):
            return (t.stride(0), t.stride(1), t.stride(2), t.stride(3))
    elif layout == 'bshd':
        def ordered_strides(t):
            return (t.stride(0), t.stride(2), t.stride(1), t.stride(3))
    elif layout == 'thd':
        def ordered_strides(t):
            # No batch axis in the packed layout, so its stride slot is 0.
            return (0, t.stride(1), t.stride(0), t.stride(2))
    else:
        assert False, 'Got unsupported layout.'

    q_strides, k_strides, v_strides, o_strides = (ordered_strides(t) for t in (q, k, v, o))
    return q_strides, k_strides, v_strides, o_strides
def get_padded_headsize(size):
    """Round ``size`` up to a power of two, with a floor of 16.

    Smallest head_dim supported is 16. If smaller, the tile in the kernel is
    padded - there is no padding in memory for any dims.
    """
    # Smallest power of two that is >= size.
    next_pow2 = 1 << (size - 1).bit_length()
    return next_pow2 if next_pow2 > 16 else 16
def _strides(x: torch.Tensor, *stride_names: str):
    """Map ``stride_<name>`` to the stride of the matching dimension of ``x``.

    Names are matched to dimensions positionally. When ``x`` is None every
    stride is reported as 0.

    Raises:
        AssertionError: if ``x`` is given and its rank differs from the
            number of names.
    """
    keys = [f"stride_{name}" for name in stride_names]
    if x is None:
        return dict.fromkeys(keys, 0)
    assert x.ndim == len(stride_names)
    return {key: x.stride(axis) for axis, key in enumerate(keys)}
def get_input_shapes():
    """Return the fixed list of 20 six-element shape tuples.

    For each exponent e in 8..17 the tuple is
    ``(max(1, 2**(16 - e)), 1, 2**e, 16, g, 128)``; the full sweep is emitted
    first with g = 1 and then again with g = 2.
    """
    # NOTE(review): the meaning of each tuple slot is not visible here -
    # confirm against the consumer of this list.
    return [
        (max(1, 2**(16 - exponent)), 1, 2**exponent, 16, fifth, 128)
        for fifth in (1, 2)
        for exponent in range(8, 18)
    ]
def is_hip():
    """Return True when the active Triton target reports the "hip" backend."""
    current_target = triton.runtime.driver.active.get_current_target()
    return current_target.backend == "hip"
def is_cdna():
    """Return True when running on HIP with one of the gfx9xx archs below."""
    # Only query the arch once the backend is known to be HIP.
    if not is_hip():
        return False
    arch = triton.runtime.driver.active.get_current_target().arch
    return arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
def is_rdna():
    """Return True when running on HIP with one of the gfx1xxx archs below."""
    # Only query the arch once the backend is known to be HIP.
    if not is_hip():
        return False
    arch = triton.runtime.driver.active.get_current_target().arch
    return arch in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")
|