test_attention.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. import random
  2. from typing import List, Optional, Tuple
  3. import pytest
  4. import torch
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.common.utils import (get_max_shared_memory_bytes, is_hip,
  7. seed_everything)
  8. from tests.kernels.utils import opcheck
  9. from .allclose_default import get_default_atol, get_default_rtol
  10. if not is_hip():
  11. from xformers import ops as xops
  12. from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
  13. FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
  14. # This will change depending on the compute capability.
  15. # - 512 as a buffer
  16. MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
  17. # There may not be enough gpu memory due to large NUM_BLOCKS.
  18. # Reduce NUM_BLOCKS when it happens.
  19. NUM_BLOCKS = 4321 # Arbitrary values for testing
  20. PARTITION_SIZE = 512
  21. # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
  22. DTYPES = [torch.half, torch.bfloat16, torch.float
  23. ] if not is_hip() else [torch.half, torch.bfloat16]
  24. NUM_GEN_SEQS = [7] # Arbitrary values for testing
  25. NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
  26. NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
  27. # FlashAttention forward only supports head dimension at most 128
  28. # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
  29. HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
  30. BLOCK_SIZES = [16, 32]
  31. USE_ALIBI = [False, True]
  32. KV_CACHE_DTYPE = ["auto", "fp8"]
  33. SEEDS = [0]
  34. CUDA_DEVICES = [
  35. f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
  36. ]
  37. def ref_masked_attention(
  38. query: torch.Tensor,
  39. key: torch.Tensor,
  40. value: torch.Tensor,
  41. scale: float,
  42. attn_mask: Optional[torch.Tensor] = None,
  43. ) -> torch.Tensor:
  44. attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
  45. if attn_mask is not None:
  46. attn_weights = attn_weights + attn_mask.float()
  47. attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
  48. out = torch.einsum("hqk,khd->qhd", attn_weights, value)
  49. return out
  50. def ref_single_query_cached_kv_attention(
  51. output: torch.Tensor,
  52. query: torch.Tensor,
  53. num_queries_per_kv: int,
  54. key_cache: torch.Tensor,
  55. value_cache: torch.Tensor,
  56. block_tables: torch.Tensor,
  57. seq_lens: torch.Tensor,
  58. scale: float,
  59. alibi_slopes: Optional[torch.Tensor],
  60. ) -> None:
  61. num_query_heads = query.shape[1]
  62. num_kv_heads = value_cache.shape[1]
  63. head_size = value_cache.shape[2]
  64. block_size = value_cache.shape[3]
  65. num_seqs = query.shape[0]
  66. block_tables_lst = block_tables.cpu().tolist()
  67. seq_lens_lst = seq_lens.cpu().tolist()
  68. for i in range(num_seqs):
  69. q = query[i].unsqueeze(0)
  70. block_table = block_tables_lst[i]
  71. seq_len = int(seq_lens_lst[i])
  72. keys_lst: List[torch.Tensor] = []
  73. values_lst: List[torch.Tensor] = []
  74. for j in range(seq_len):
  75. block_number = int(block_table[j // block_size])
  76. block_offset = j % block_size
  77. k = key_cache[block_number, :, :, block_offset, :]
  78. k = k.reshape(num_kv_heads, head_size)
  79. keys_lst.append(k)
  80. v = value_cache[block_number, :, :, block_offset]
  81. values_lst.append(v)
  82. keys = torch.stack(keys_lst, dim=0)
  83. values = torch.stack(values_lst, dim=0)
  84. if num_queries_per_kv > 1:
  85. # Handle MQA and GQA
  86. keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
  87. values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
  88. alibi_bias = None
  89. if alibi_slopes is not None:
  90. # Create the ALiBi bias used in the paged attention kernel.
  91. position_ids = torch.arange(seq_len).int()
  92. alibi_bias = (position_ids - seq_len + 1).float()
  93. alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
  94. 1, 1, -1)
  95. out = ref_masked_attention(q, keys, values, scale, alibi_bias)
  96. out = out.view(num_query_heads, head_size)
  97. output[i].copy_(out, non_blocking=True)
  98. @pytest.mark.parametrize(
  99. "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
  100. @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
  101. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  102. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  103. @pytest.mark.parametrize("use_alibi", USE_ALIBI)
  104. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  105. @pytest.mark.parametrize("dtype", DTYPES)
  106. @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
  107. @pytest.mark.parametrize("seed", SEEDS)
  108. @pytest.mark.parametrize("device", CUDA_DEVICES)
  109. def test_paged_attention(
  110. kv_cache_factory,
  111. version: str,
  112. num_seqs: int,
  113. num_heads: Tuple[int, int],
  114. head_size: int,
  115. use_alibi: bool,
  116. block_size: int,
  117. dtype: torch.dtype,
  118. kv_cache_dtype: str,
  119. seed: int,
  120. device: str,
  121. ) -> None:
  122. if ((kv_cache_dtype == "fp8" and head_size % 16)
  123. or (version == "rocm" and head_size not in (64, 128))):
  124. pytest.skip()
  125. seed_everything(seed)
  126. torch.set_default_device(device)
  127. scale = float(1.0 / (head_size**0.5))
  128. num_query_heads, num_kv_heads = num_heads
  129. query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
  130. query.uniform_(-scale, scale)
  131. assert num_query_heads % num_kv_heads == 0
  132. num_queries_per_kv = num_query_heads // num_kv_heads
  133. alibi_slopes = None
  134. if use_alibi:
  135. alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
  136. seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
  137. seq_lens[-1] = MAX_SEQ_LEN
  138. max_seq_len = max(seq_lens)
  139. seq_lens = torch.tensor(seq_lens, dtype=torch.int)
  140. # Create the block tables.
  141. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
  142. block_tables_lst: List[List[int]] = []
  143. for _ in range(num_seqs):
  144. block_table = [
  145. random.randint(0, NUM_BLOCKS - 1)
  146. for _ in range(max_num_blocks_per_seq)
  147. ]
  148. block_tables_lst.append(block_table)
  149. block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
  150. # Create the KV caches.
  151. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
  152. num_kv_heads, head_size,
  153. kv_cache_dtype, dtype, seed,
  154. device)
  155. key_cache, value_cache = key_caches[0], value_caches[0]
  156. # Using default kv_scale
  157. k_scale = v_scale = 1.0
  158. # Call the paged attention kernel.
  159. output = torch.empty_like(query)
  160. if version == "v1":
  161. ops.paged_attention_v1(
  162. output,
  163. query,
  164. key_cache,
  165. value_cache,
  166. num_kv_heads,
  167. scale,
  168. block_tables,
  169. seq_lens,
  170. block_size,
  171. max_seq_len,
  172. alibi_slopes,
  173. kv_cache_dtype,
  174. k_scale,
  175. v_scale,
  176. )
  177. opcheck(torch.ops._C.paged_attention_v1,
  178. (output, query, key_cache, value_cache, num_kv_heads, scale,
  179. block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
  180. kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
  181. cond=(head_size == HEAD_SIZES[0]))
  182. elif version in ("v2", "rocm"):
  183. num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
  184. assert PARTITION_SIZE % block_size == 0
  185. num_seqs, num_heads, head_size = output.shape
  186. tmp_output = torch.empty(
  187. size=(num_seqs, num_heads, num_partitions, head_size),
  188. dtype=output.dtype,
  189. )
  190. exp_sums = torch.empty(
  191. size=(num_seqs, num_heads, num_partitions),
  192. dtype=torch.float32,
  193. )
  194. max_logits = torch.empty_like(exp_sums)
  195. if version == "v2":
  196. ops.paged_attention_v2(
  197. output,
  198. exp_sums,
  199. max_logits,
  200. tmp_output,
  201. query,
  202. key_cache,
  203. value_cache,
  204. num_kv_heads,
  205. scale,
  206. block_tables,
  207. seq_lens,
  208. block_size,
  209. max_seq_len,
  210. alibi_slopes,
  211. kv_cache_dtype,
  212. k_scale,
  213. v_scale,
  214. )
  215. opcheck(torch.ops._C.paged_attention_v2,
  216. (output, exp_sums, max_logits, tmp_output, query,
  217. key_cache, value_cache, num_kv_heads, scale, block_tables,
  218. seq_lens, block_size, max_seq_len, alibi_slopes,
  219. kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
  220. cond=(head_size == HEAD_SIZES[0]))
  221. else:
  222. ops.paged_attention_rocm(
  223. output,
  224. exp_sums,
  225. max_logits,
  226. tmp_output,
  227. query,
  228. key_cache,
  229. value_cache,
  230. num_kv_heads,
  231. scale,
  232. block_tables,
  233. seq_lens,
  234. block_size,
  235. max_seq_len,
  236. alibi_slopes,
  237. kv_cache_dtype,
  238. k_scale,
  239. v_scale,
  240. )
  241. opcheck(torch.ops._rocm_C.paged_attention,
  242. (output, exp_sums, max_logits, tmp_output, query,
  243. key_cache, value_cache, num_kv_heads, scale, block_tables,
  244. seq_lens, block_size, max_seq_len, alibi_slopes,
  245. kv_cache_dtype, k_scale, v_scale),
  246. cond=(head_size == HEAD_SIZES[0]))
  247. else:
  248. raise AssertionError(f"Unknown version: {version}")
  249. # Run the reference implementation.
  250. if kv_cache_dtype == "fp8":
  251. # Convert cache data back to dtype.
  252. x = 16 // torch.tensor([], dtype=dtype).element_size()
  253. key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
  254. block_size, x)
  255. dequantized_key_cache = torch.empty(size=key_cache_shape,
  256. dtype=dtype,
  257. device=device)
  258. ops.convert_fp8(dequantized_key_cache, key_cache)
  259. key_cache = dequantized_key_cache
  260. value_cache_shape = value_cache.shape
  261. dequantized_value_cache = torch.empty(size=value_cache_shape,
  262. dtype=dtype,
  263. device=device)
  264. ops.convert_fp8(dequantized_value_cache, value_cache)
  265. value_cache = dequantized_value_cache
  266. ref_output = torch.empty_like(query)
  267. ref_single_query_cached_kv_attention(
  268. ref_output,
  269. query,
  270. num_queries_per_kv,
  271. key_cache,
  272. value_cache,
  273. block_tables,
  274. seq_lens,
  275. scale,
  276. alibi_slopes,
  277. )
  278. # NOTE(woosuk): Due to the kernel-level differences in the two
  279. # implementations, there is a small numerical difference in the two
  280. # outputs. Thus, we use a relaxed tolerance for the test.
  281. atol = get_default_atol(output) if is_hip() else 1e-3
  282. rtol = get_default_rtol(output) if is_hip() else 1e-5
  283. # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
  284. # so we use a relaxed tolerance for the test.
  285. atol, rtol = 1e-3, 1e-5
  286. if kv_cache_dtype == "fp8":
  287. atol, rtol = 1e-2, 1e-5
  288. torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
  289. def ref_multi_query_kv_attention(
  290. cu_seq_lens: List[int],
  291. query: torch.Tensor,
  292. key: torch.Tensor,
  293. value: torch.Tensor,
  294. scale: float,
  295. dtype: torch.dtype,
  296. ) -> torch.Tensor:
  297. num_seqs = len(cu_seq_lens) - 1
  298. ref_outputs: List[torch.Tensor] = []
  299. for i in range(num_seqs):
  300. start_idx = cu_seq_lens[i]
  301. end_idx = cu_seq_lens[i + 1]
  302. seq_len = end_idx - start_idx
  303. # Create attention mask.
  304. attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
  305. diagonal=1)
  306. attn_mask = attn_mask * torch.finfo(dtype).min
  307. attn_mask = attn_mask.to(dtype=dtype)
  308. ref_output = ref_masked_attention(
  309. query[start_idx:end_idx],
  310. key[start_idx:end_idx],
  311. value[start_idx:end_idx],
  312. scale,
  313. attn_mask=attn_mask,
  314. )
  315. ref_outputs.append(ref_output)
  316. return torch.cat(ref_outputs, dim=0)
  317. # TODO(woosuk): Add tests for USE_ALIBI=True.
  318. @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
  319. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  320. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  321. @pytest.mark.parametrize("dtype", DTYPES)
  322. @pytest.mark.parametrize("seed", SEEDS)
  323. @pytest.mark.parametrize("device", CUDA_DEVICES)
  324. @pytest.mark.skipif(is_hip(),
  325. reason="Xformers backend is not supported on ROCm.")
  326. @torch.inference_mode()
  327. def test_multi_query_kv_attention(
  328. num_seqs: int,
  329. num_heads: Tuple[int, int],
  330. head_size: int,
  331. dtype: torch.dtype,
  332. seed: int,
  333. device: str,
  334. ) -> None:
  335. seed_everything(seed)
  336. torch.set_default_device(device)
  337. # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
  338. # As the xformers library is already tested with its own tests, we can use
  339. # a smaller MAX_SEQ_LEN here.
  340. max_len = min(MAX_SEQ_LEN, 4096)
  341. seq_lens = random.sample(range(1, max_len), num_seqs)
  342. num_tokens = sum(seq_lens)
  343. scale = float(1.0 / (head_size**0.5))
  344. num_query_heads, num_kv_heads = num_heads
  345. qkv = torch.empty(num_tokens,
  346. num_query_heads + 2 * num_kv_heads,
  347. head_size,
  348. dtype=dtype)
  349. qkv.uniform_(-scale, scale)
  350. query, key, value = qkv.split(
  351. [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
  352. num_queries_per_kv = num_query_heads // num_kv_heads
  353. if num_queries_per_kv > 1:
  354. # Handle MQA and GQA
  355. key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
  356. value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
  357. attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
  358. output = xops.memory_efficient_attention_forward(
  359. query.unsqueeze(0),
  360. key.unsqueeze(0),
  361. value.unsqueeze(0),
  362. attn_bias=attn_bias,
  363. p=0.0,
  364. scale=scale,
  365. )
  366. output = output.squeeze(0)
  367. cu_seq_lens = [0]
  368. for seq_len in seq_lens:
  369. cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
  370. ref_output = ref_multi_query_kv_attention(
  371. cu_seq_lens,
  372. query,
  373. key,
  374. value,
  375. scale,
  376. dtype,
  377. )
  378. atol = get_default_atol(output) if is_hip() else 1e-3
  379. rtol = get_default_rtol(output) if is_hip() else 1e-5
  380. torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)