test_blocksparse_attention.py

import random
from typing import List, Optional, Tuple

import pytest
import torch

from aphrodite import _custom_ops as ops
from aphrodite.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn)
from aphrodite.common.utils import get_max_shared_memory_bytes, is_hip

from .allclose_default import get_default_atol, get_default_rtol

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# MAX_SEQ_LEN depends on the available shared memory (and thus on the
# compute capability); subtract 512 as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# MAX_SEQ_LEN = 2771
# There may not be enough GPU memory due to the large NUM_BLOCKS.
# Reduce NUM_BLOCKS if that happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [3]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 112]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']

BLOCKSPARSE_LOCAL_BLOCKS = [16]
BLOCKSPARSE_VERT_STRIDES = [8]
BLOCKSPARSE_BLOCK_SIZES = [64]
BLOCKSPARSE_HEADS_SLIDINGS = [0, 2, -1]
BLOCKSPARSE_HOMO_HEADS = [True, False]
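

# Naive reference attention: materializes the full (head, query, key) score
# matrix, applies an optional additive mask (e.g. ALiBi or block-sparse),
# and softmaxes in float32 before the value reduction.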
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out
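

# Reference single-token (decode) attention over the paged KV cache: for each
# sequence, gather its keys/values via the block table, expand KV heads for
# MQA/GQA, build the optional ALiBi and block-sparse masks, and run
# ref_masked_attention on the single query token.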
def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 1,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables_lst = block_tables.cpu().tolist()
    seq_lens_lst = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables_lst[i]
        seq_len = int(seq_lens_lst[i])

        keys_lst: List[torch.Tensor] = []
        values_lst: List[torch.Tensor] = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size
            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys_lst.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values_lst.append(v)
        keys = torch.stack(keys_lst, dim=0)
        values = torch.stack(values_lst, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values,
                                             num_queries_per_kv,
                                             dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)
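
        # Build the block-sparse mask: a key block is kept (bias 0) if it is
        # within `blocksparse_local_blocks` blocks of the query block, or if
        # its offset index lands on the vertical stride; all other positions
        # stay at -inf.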
        if blocksparse_vert_stride >= 1:
            bsize = blocksparse_block_size
            hsliding = blocksparse_head_sliding_step
            vert = blocksparse_vert_stride
            local_blocks = blocksparse_local_blocks
            qb = (seq_len - 1) // bsize
            attn_mask = q.new_zeros(
                (num_query_heads, 1, seq_len)).float() - torch.inf
            for h in range(num_query_heads):
                if hsliding >= 0:  # slide with q heads
                    bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1
                else:  # slide with kv heads
                    bs_offset = (tp_rank * num_kv_heads +
                                 h // num_queries_per_kv) * (-hsliding) + 1
                for kb in range(qb + 1):
                    kj = kb * bsize
                    if (qb - kb) < local_blocks or \
                            (kb + bs_offset) % vert == 0:
                        attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0
            if alibi_bias is not None:
                attn_mask += alibi_bias
        else:
            attn_mask = alibi_bias

        out = ref_masked_attention(q,
                                   keys,
                                   values,
                                   scale,
                                   attn_mask=attn_mask)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)
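

# Check the block-sparse paged attention kernels (v1 and v2) against the
# reference decode implementation above.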
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_head_sliding_step",
                         BLOCKSPARSE_HEADS_SLIDINGS)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
    blocksparse_local_blocks: int,
    blocksparse_vert_stride: int,
    blocksparse_block_size: int,
    blocksparse_head_sliding_step: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.rand(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    k_scale = v_scale = 1.0
    tp_rank = 0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank=tp_rank,
            blocksparse_local_blocks=blocksparse_local_blocks,
            blocksparse_vert_stride=blocksparse_vert_stride,
            blocksparse_block_size=blocksparse_block_size,
            blocksparse_head_sliding_step=blocksparse_head_sliding_step,
        )
    elif version == "v2":
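        # v2 splits each sequence into fixed-size partitions and reduces the
        # per-partition results (tmp_output, exp_sums, max_logits) at the end.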
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank=tp_rank,
            blocksparse_local_blocks=blocksparse_local_blocks,
            blocksparse_vert_stride=blocksparse_vert_stride,
            blocksparse_block_size=blocksparse_block_size,
            blocksparse_head_sliding_step=blocksparse_head_sliding_step,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert the cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )

    # NOTE: Due to kernel-level differences between the two implementations,
    # there is a small numerical difference between the two outputs,
    # so we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    # NOTE: The FP8 KV cache introduces quantization error,
    # so we use an even more relaxed tolerance in that case.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
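

# Reference causal (prefill) attention for variable-length sequences packed
# into one batch: each sequence is sliced out via cu_seq_lens and run through
# ref_masked_attention with an upper-triangular mask.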
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output
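

# Check the variable-length block-sparse prefill kernel
# (LocalStridedBlockSparseAttn) against the dense causal reference above.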
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_varlen_blocksparse_attention_prefill(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    blocksparse_local_blocks: int,
    blocksparse_vert_stride: int,
    blocksparse_block_size: int,
    blocksparse_homo_heads: bool,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # Since the reference is only used for correctness checking, a smaller
    # maximum length is sufficient here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads

    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    bs_attn_op = LocalStridedBlockSparseAttn(
        num_query_heads,
        max_len,
        local_blocks=blocksparse_local_blocks,
        vert_stride=blocksparse_vert_stride,
        block_size=blocksparse_block_size,
        device=device,
        dtype=dtype,
        homo_head=blocksparse_homo_heads)

    output = bs_attn_op(query,
                        key,
                        value,
                        cu_seq_lens.to(device),
                        sm_scale=scale)

    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens.tolist(),
        query,
        key,
        value,
        scale,
        dtype,
    )
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)