test_attention.py
import random
from typing import List, Optional, Tuple

import pytest
import torch

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import get_max_shared_memory_bytes, is_hip
from tests.kernels.utils import opcheck

from .allclose_default import get_default_atol, get_default_rtol

if not is_hip():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
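# Exercise a second (non-default) CUDA device when more than one GPU is
# visible, presumably to catch kernels that hard-code device index 0.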
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
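    # Shapes (as implied by the einsum below): query is
    # [num_query_tokens, num_heads, head_size], key and value are
    # [num_kv_tokens, num_heads, head_size], and the output matches the
    # query layout.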
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
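    # Cache layouts, matching the indexing below: key_cache is
    # [num_blocks, num_kv_heads, head_size // x, block_size, x] (x is the
    # packed vector width) and value_cache is
    # [num_blocks, num_kv_heads, head_size, block_size].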
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values,
                                             num_queries_per_kv,
                                             dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
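            # The bias for key position j is the per-head slope times
            # (j - seq_len + 1): zero for the last (current) token and
            # increasingly negative for earlier tokens.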
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
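    # Each block table entry maps a logical block index (token position //
    # block_size) to a physical block in the cache. Random, possibly
    # overlapping physical blocks are fine here because the kernels under
    # test only read from the cache.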
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
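    # k_scale and v_scale are the FP8 quantization scales for the key and
    # value caches; 1.0 leaves the cached values unscaled.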
    k_scale = v_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
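
        # opcheck validates the custom-op registration against its schema;
        # the trailing (0, 0, 0, 64, 0) arguments are assumed to be tp_rank
        # plus the block-sparse attention parameters left at their defaults.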
        opcheck(torch.ops._C.paged_attention_v1,
                (output, query, key_cache, value_cache, num_kv_heads, scale,
                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
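        # v2 splits each sequence into partitions of PARTITION_SIZE tokens,
        # writes per-partition results into tmp_output, and reduces them
        # using the saved exp_sums and max_logits.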
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

        opcheck(torch.ops._C.paged_attention_v2,
                (output, exp_sums, max_logits, tmp_output, query, key_cache,
                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
                 k_scale, v_scale, 0, 0, 0, 64, 0),
                cond=(head_size == HEAD_SIZES[0]))

    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
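        # x is the number of elements of `dtype` that fit in 16 bytes, i.e.
        # the vector width of the key cache's innermost dimension.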
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs: List[torch.Tensor] = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx
        # Create attention mask.
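        # Upper-triangular entries (future positions) are set to the dtype's
        # minimum so softmax gives them ~zero weight, i.e. a causal mask.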
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)

    return torch.cat(ref_outputs, dim=0)


@pytest.mark.parametrize("version", ["rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not is_hip(), reason="only for rocm")
def test_paged_attention_rocm(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    #context_lens = [8192 for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int)
    #print('>>> ctx lens', context_lens)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # TODO enable fp8 kv cache
    # Using default kv_scale
    # kv_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
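    # The ROCm kernel expects the v2-style partition workspace (tmp_output,
    # exp_sums, max_logits), so it is allocated up front here.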
    PARTITION_SIZE_ROCM = 256
    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
                      PARTITION_SIZE_ROCM)
    assert PARTITION_SIZE_ROCM % block_size == 0
    num_seqs, num_heads, head_size = output.shape
    tmp_output = torch.empty(
        size=(num_seqs, num_heads, num_partitions, head_size),
        dtype=output.dtype,
    )
    exp_sums = torch.empty(
        size=(num_seqs, num_heads, num_partitions),
        dtype=torch.float32,
    )
    max_logits = torch.empty_like(exp_sums)
    if version == "rocm":
        ops.paged_attention_rocm(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        # NOTE: the destination tensor comes first, matching the argument
        # order used in test_paged_attention above.
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE: Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE: FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-4, 1e-5
    if dtype == torch.bfloat16:
        atol, rtol = 2e-4, 1e-5
    if use_alibi:
        if dtype == torch.half:
            atol, rtol = 5e-4, 1e-5
        if dtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)


# TODO: Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(is_hip(), reason="skip for rocm")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
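    # BlockDiagonalCausalMask restricts attention to earlier tokens within
    # the same sequence, which the per-sequence causal reference below
    # reproduces.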
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
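    # cu_seq_lens holds cumulative sequence lengths: tokens of sequence i
    # occupy [cu_seq_lens[i], cu_seq_lens[i + 1]) in the packed tensors.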
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)