utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import torch
  2. import os
  3. import triton
  4. AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
  5. DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes')
  6. PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')
  7. class MetaData():
  8. cu_seqlens_q = None
  9. cu_seqlens_k = None
  10. max_seqlens_q = 0
  11. max_seqlens_k = 0
  12. bias = None
  13. alibi_slopes = None
  14. causal = False
  15. num_contexts = 0
  16. varlen = False
  17. layout = None
  18. cache_seqlens = None
  19. cache_batch_idx = None
  20. new_kv = False
  21. seqlen_new = None
  22. k_new = None
  23. v_new = None
  24. dropout_p, return_scores= 0.0, False
  25. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
  26. use_exp2 = False
  27. def __repr__(self) -> str:
  28. return (f"MetaData(\n"
  29. f" sm_scale={self.sm_scale},\n"
  30. f" cu_seqlens_q={self.cu_seqlens_q},\n"
  31. f" cu_seqlens_k={self.cu_seqlens_k},\n"
  32. f" max_seqlens_q={self.max_seqlens_q},\n"
  33. f" max_seqlens_k={self.max_seqlens_k},\n"
  34. f" bias={self.bias},\n"
  35. f" alibi_slopes={self.alibi_slopes},\n"
  36. f" causal={self.causal},\n"
  37. f" num_contexts={self.num_contexts},\n"
  38. f" varlen={self.varlen},\n"
  39. f" layout={self.layout},\n"
  40. f" cache_seqlens={self.cache_seqlens},\n"
  41. f" cache_batch_idx={self.cache_batch_idx},\n"
  42. f" new_kv={self.new_kv},\n"
  43. f" seqlen_new={self.seqlen_new},\n"
  44. f" k_new={self.k_new},\n"
  45. f" v_new={self.v_new},\n"
  46. f" dropout_p={self.dropout_p},\n"
  47. f" return_scores={self.return_scores}\n"
  48. f")")
  49. def __init__(self, sm_scale=1.0):
  50. self.sm_scale = sm_scale
  51. def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
  52. self.varlen = True
  53. self.layout = 'thd'
  54. self.cu_seqlens_q = cu_seqlens_q
  55. self.cu_seqlens_k = cu_seqlens_k
  56. # Without "varlen", there should still be one sequence.
  57. assert len(cu_seqlens_q) >= 2
  58. assert len(cu_seqlens_q) == len(cu_seqlens_k)
  59. self.num_contexts = len(cu_seqlens_q) - 1
  60. for i in range(0, self.num_contexts):
  61. self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q)
  62. self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k)
  63. def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):
  64. assert bias.is_cuda
  65. assert bias.dim() == 4
  66. assert bias.shape[0] == 1
  67. assert bias.shape[2:] == (seqlen_q, seqlen_k)
  68. self.bias = bias
  69. def need_alibi(self, alibi_slopes, batch, nheads):
  70. assert alibi_slopes.is_cuda
  71. assert alibi_slopes.dim() == 2
  72. assert alibi_slopes.shape[0] == batch
  73. assert alibi_slopes.shape[1] == nheads
  74. self.alibi_slopes = alibi_slopes
  75. def need_causal(self):
  76. self.causal = True
  77. def need_dropout(self, dropout_p, return_scores):
  78. self.dropout_p = dropout_p
  79. self.return_scores = return_scores
  80. def check_args(self, q, k, v, o):
  81. assert q.dim() == k.dim() and q.dim() == v.dim()
  82. batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k)
  83. if self.varlen:
  84. assert q.dim() == 3
  85. assert self.cu_seqlens_q is not None
  86. assert self.cu_seqlens_k is not None
  87. assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
  88. # TODO: Remove once bias is supported with varlen
  89. assert self.bias is None
  90. # TODO:Remove once dropout is supported with varlen
  91. assert self.dropout_p == 0.0
  92. # assert not self.return_scores
  93. else:
  94. assert q.dim() == 4
  95. assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
  96. assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
  97. assert k.shape == v.shape
  98. assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
  99. # TODO: Change assert if we support qkl f8 and v f16
  100. assert q.dtype == k.dtype and q.dtype == v.dtype
  101. assert head_size <= 256
  102. assert o.shape == q.shape
  103. assert (nheads_q % nheads_k) == 0
  104. assert self.layout is not None
  105. assert self.layout == 'thd' or not self.varlen
  106. def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False):
  107. torch.manual_seed(20)
  108. # Initialize q, k, v
  109. if layout == 'bhsd':
  110. q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
  111. k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
  112. elif layout == 'bshd':
  113. q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
  114. k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
  115. else:
  116. assert False, f'Got unsupported tensor layout: {layout}'
  117. if DEBUG_INPUT:
  118. if layout == "bhsd":
  119. q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
  120. k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
  121. v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
  122. elif layout == "bshd":
  123. q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
  124. k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
  125. v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
  126. else:
  127. q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True)
  128. k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
  129. v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
  130. if DEBUG_INPUT:
  131. sm_scale = 1
  132. else:
  133. sm_scale = D_HEAD**-0.5
  134. input_metadata = MetaData(sm_scale=sm_scale)
  135. input_metadata.max_seqlens_q = N_CTX_Q
  136. input_metadata.max_seqlens_k = N_CTX_K
  137. input_metadata.layout = layout
  138. return q, k, v, input_metadata
  139. def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
  140. torch.manual_seed(20)
  141. # Random or equal sequence lengths based on 'equal_seqlens' flag
  142. if not equal_seqlens:
  143. max_seqlens_q = N_CTX_Q // Z
  144. max_seqlens_k = N_CTX_K // Z
  145. seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
  146. seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
  147. else:
  148. seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32)
  149. seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32)
  150. # Calculate cumulative sequence lengths
  151. cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)])
  152. cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)])
  153. cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32)
  154. cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32)
  155. # Total lengths
  156. total_q = cu_seqlens_q[-1].item()
  157. total_k = cu_seqlens_k[-1].item()
  158. if DEBUG_INPUT:
  159. # Initialize q, k, v with deterministic values
  160. q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1)
  161. q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_()
  162. k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
  163. k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
  164. v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
  165. v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
  166. sm_scale = 1
  167. else:
  168. # Initialize q, k, v with random values
  169. q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
  170. k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
  171. v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
  172. sm_scale = D_HEAD ** -0.5
  173. input_metadata = MetaData(sm_scale=sm_scale)
  174. input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
  175. return q, k, v, input_metadata
  176. def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
  177. if layout == 'bhsd':
  178. batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape
  179. batch_k, nheads_k, max_seqlen_k, head_size_k = k.shape
  180. elif layout == 'bshd':
  181. batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape
  182. batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape
  183. elif layout == 'thd':
  184. batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2]
  185. batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2]
  186. else:
  187. assert False, "Got unsupported layout."
  188. # assert
  189. assert batch_q == batch_k
  190. assert head_size_q == head_size_k
  191. return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k
  192. def get_strides_from_layout(q, k, v, o, layout):
  193. if layout == 'thd':
  194. q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
  195. k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
  196. v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
  197. o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
  198. elif layout == 'bhsd':
  199. q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
  200. k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
  201. v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
  202. o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
  203. elif layout == 'bshd':
  204. q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
  205. k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
  206. v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
  207. o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
  208. else:
  209. assert False, 'Got unsupported layout.'
  210. return q_strides, k_strides, v_strides, o_strides
  211. def get_padded_headsize(size):
  212. # Get closest power of 2 over or equal to 32.
  213. padded_d_model = 1 << (size - 1).bit_length()
  214. # Smallest head_dim supported is 16. If smaller, the tile in the
  215. # kernel is padded - there is no padding in memory for any dims.
  216. padded_d_model = max(padded_d_model, 16)
  217. return padded_d_model
  218. def _strides(x: torch.Tensor, *stride_names: str):
  219. if x is None:
  220. return {f"stride_{s}": 0 for i, s in enumerate(stride_names)}
  221. assert x.ndim == len(stride_names)
  222. return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
  223. def get_input_shapes():
  224. cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128)
  225. for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)]
  226. return cases
  227. def is_hip():
  228. return triton.runtime.driver.active.get_current_target().backend == "hip"
  229. def is_cdna():
  230. return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
  231. 'gfx90a', 'gfx908')
  232. def is_rdna():
  233. return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101",
  234. "gfx1102", "gfx1200", "gfx1201")