test_attention.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. import random
  2. from typing import List, Optional, Tuple
  3. import pytest
  4. import torch
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.common.utils import get_max_shared_memory_bytes, is_hip
  9. from .allclose_default import get_default_atol, get_default_rtol
  10. FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
  11. # This will change depending on the compute capability.
  12. # - 512 as a buffer
  13. MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
  14. # There may not be enough gpu memory due to large NUM_BLOCKS.
  15. # Reduce NUM_BLOCKS when it happens.
  16. NUM_BLOCKS = 4321 # Arbitrary values for testing
  17. PARTITION_SIZE = 512
  18. # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
  19. DTYPES = [torch.half, torch.bfloat16, torch.float
  20. ] if not is_hip() else [torch.half, torch.bfloat16]
  21. NUM_GEN_SEQS = [7] # Arbitrary values for testing
  22. NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
  23. NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
  24. # FlashAttention forward only supports head dimension at most 128
  25. # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
  26. HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
  27. ] if not is_hip() else [64, 80, 96, 112, 128]
  28. BLOCK_SIZES = [16, 32]
  29. USE_ALIBI = [False, True]
  30. KV_CACHE_DTYPE = ["auto", "fp8"]
  31. SEEDS = [0]
  32. CUDA_DEVICES = [
  33. f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
  34. ]
  35. def ref_masked_attention(
  36. query: torch.Tensor,
  37. key: torch.Tensor,
  38. value: torch.Tensor,
  39. scale: float,
  40. attn_mask: Optional[torch.Tensor] = None,
  41. ) -> torch.Tensor:
  42. attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
  43. if attn_mask is not None:
  44. attn_weights = attn_weights + attn_mask.float()
  45. attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
  46. out = torch.einsum("hqk,khd->qhd", attn_weights, value)
  47. return out
  48. def ref_single_query_cached_kv_attention(
  49. output: torch.Tensor,
  50. query: torch.Tensor,
  51. num_queries_per_kv: int,
  52. key_cache: torch.Tensor,
  53. value_cache: torch.Tensor,
  54. block_tables: torch.Tensor,
  55. seq_lens: torch.Tensor,
  56. scale: float,
  57. alibi_slopes: Optional[torch.Tensor],
  58. ) -> None:
  59. num_query_heads = query.shape[1]
  60. num_kv_heads = value_cache.shape[1]
  61. head_size = value_cache.shape[2]
  62. block_size = value_cache.shape[3]
  63. num_seqs = query.shape[0]
  64. block_tables_lst = block_tables.cpu().tolist()
  65. seq_lens_lst = seq_lens.cpu().tolist()
  66. for i in range(num_seqs):
  67. q = query[i].unsqueeze(0)
  68. block_table = block_tables_lst[i]
  69. seq_len = int(seq_lens_lst[i])
  70. keys_lst: List[torch.Tensor] = []
  71. values_lst: List[torch.Tensor] = []
  72. for j in range(seq_len):
  73. block_number = int(block_table[j // block_size])
  74. block_offset = j % block_size
  75. k = key_cache[block_number, :, :, block_offset, :]
  76. k = k.reshape(num_kv_heads, head_size)
  77. keys_lst.append(k)
  78. v = value_cache[block_number, :, :, block_offset]
  79. values_lst.append(v)
  80. keys = torch.stack(keys_lst, dim=0)
  81. values = torch.stack(values_lst, dim=0)
  82. if num_queries_per_kv > 1:
  83. # Handle MQA and GQA
  84. keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
  85. values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
  86. alibi_bias = None
  87. if alibi_slopes is not None:
  88. # Create the ALiBi bias used in the paged attention kernel.
  89. position_ids = torch.arange(seq_len).int()
  90. alibi_bias = (position_ids - seq_len + 1).float()
  91. alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
  92. 1, 1, -1)
  93. out = ref_masked_attention(q, keys, values, scale, alibi_bias)
  94. out = out.view(num_query_heads, head_size)
  95. output[i].copy_(out, non_blocking=True)
  96. @pytest.mark.parametrize("version", ["v1", "v2"])
  97. @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
  98. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  99. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  100. @pytest.mark.parametrize("use_alibi", USE_ALIBI)
  101. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  102. @pytest.mark.parametrize("dtype", DTYPES)
  103. @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
  104. @pytest.mark.parametrize("seed", SEEDS)
  105. @pytest.mark.parametrize("device", CUDA_DEVICES)
  106. def test_paged_attention(
  107. kv_cache_factory,
  108. version: str,
  109. num_seqs: int,
  110. num_heads: Tuple[int, int],
  111. head_size: int,
  112. use_alibi: bool,
  113. block_size: int,
  114. dtype: torch.dtype,
  115. kv_cache_dtype: str,
  116. seed: int,
  117. device: str,
  118. ) -> None:
  119. if kv_cache_dtype == "fp8" and head_size % 16:
  120. pytest.skip()
  121. random.seed(seed)
  122. torch.random.manual_seed(seed)
  123. if torch.cuda.is_available():
  124. torch.cuda.manual_seed(seed)
  125. torch.set_default_device(device)
  126. scale = float(1.0 / (head_size**0.5))
  127. num_query_heads, num_kv_heads = num_heads
  128. query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
  129. query.uniform_(-scale, scale)
  130. assert num_query_heads % num_kv_heads == 0
  131. num_queries_per_kv = num_query_heads // num_kv_heads
  132. alibi_slopes = None
  133. if use_alibi:
  134. alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
  135. seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
  136. seq_lens[-1] = MAX_SEQ_LEN
  137. max_seq_len = max(seq_lens)
  138. seq_lens = torch.tensor(seq_lens, dtype=torch.int)
  139. # Create the block tables.
  140. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
  141. block_tables_lst: List[List[int]] = []
  142. for _ in range(num_seqs):
  143. block_table = [
  144. random.randint(0, NUM_BLOCKS - 1)
  145. for _ in range(max_num_blocks_per_seq)
  146. ]
  147. block_tables_lst.append(block_table)
  148. block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
  149. # Create the KV caches.
  150. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
  151. num_kv_heads, head_size,
  152. kv_cache_dtype, dtype, seed,
  153. device)
  154. key_cache, value_cache = key_caches[0], value_caches[0]
  155. # Using default kv_scale
  156. k_scale = v_scale = 1.0
  157. # Call the paged attention kernel.
  158. output = torch.empty_like(query)
  159. if version == "v1":
  160. ops.paged_attention_v1(
  161. output,
  162. query,
  163. key_cache,
  164. value_cache,
  165. num_kv_heads,
  166. scale,
  167. block_tables,
  168. seq_lens,
  169. block_size,
  170. max_seq_len,
  171. alibi_slopes,
  172. kv_cache_dtype,
  173. k_scale,
  174. v_scale,
  175. )
  176. elif version == "v2":
  177. num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
  178. assert PARTITION_SIZE % block_size == 0
  179. num_seqs, num_heads, head_size = output.shape
  180. tmp_output = torch.empty(
  181. size=(num_seqs, num_heads, num_partitions, head_size),
  182. dtype=output.dtype,
  183. )
  184. exp_sums = torch.empty(
  185. size=(num_seqs, num_heads, num_partitions),
  186. dtype=torch.float32,
  187. )
  188. max_logits = torch.empty_like(exp_sums)
  189. ops.paged_attention_v2(
  190. output,
  191. exp_sums,
  192. max_logits,
  193. tmp_output,
  194. query,
  195. key_cache,
  196. value_cache,
  197. num_kv_heads,
  198. scale,
  199. block_tables,
  200. seq_lens,
  201. block_size,
  202. max_seq_len,
  203. alibi_slopes,
  204. kv_cache_dtype,
  205. k_scale,
  206. v_scale,
  207. )
  208. else:
  209. raise AssertionError(f"Unknown version: {version}")
  210. # Run the reference implementation.
  211. if kv_cache_dtype == "fp8":
  212. # Convert cache data back to dtype.
  213. x = 16 // torch.tensor([], dtype=dtype).element_size()
  214. key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
  215. block_size, x)
  216. dequantized_key_cache = torch.empty(size=key_cache_shape,
  217. dtype=dtype,
  218. device=device)
  219. ops.convert_fp8(dequantized_key_cache, key_cache)
  220. key_cache = dequantized_key_cache
  221. value_cache_shape = value_cache.shape
  222. dequantized_value_cache = torch.empty(size=value_cache_shape,
  223. dtype=dtype,
  224. device=device)
  225. ops.convert_fp8(dequantized_value_cache, value_cache)
  226. value_cache = dequantized_value_cache
  227. ref_output = torch.empty_like(query)
  228. ref_single_query_cached_kv_attention(
  229. ref_output,
  230. query,
  231. num_queries_per_kv,
  232. key_cache,
  233. value_cache,
  234. block_tables,
  235. seq_lens,
  236. scale,
  237. alibi_slopes,
  238. )
  239. # NOTE(woosuk): Due to the kernel-level differences in the two
  240. # implementations, there is a small numerical difference in the two
  241. # outputs. Thus, we use a relaxed tolerance for the test.
  242. atol = get_default_atol(output) if is_hip() else 1e-3
  243. rtol = get_default_rtol(output) if is_hip() else 1e-5
  244. # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
  245. # so we use a relaxed tolerance for the test.
  246. atol, rtol = 1e-3, 1e-5
  247. if kv_cache_dtype == "fp8":
  248. atol, rtol = 1e-2, 1e-5
  249. torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
  250. def ref_multi_query_kv_attention(
  251. cu_seq_lens: List[int],
  252. query: torch.Tensor,
  253. key: torch.Tensor,
  254. value: torch.Tensor,
  255. scale: float,
  256. dtype: torch.dtype,
  257. ) -> torch.Tensor:
  258. num_seqs = len(cu_seq_lens) - 1
  259. ref_outputs: List[torch.Tensor] = []
  260. for i in range(num_seqs):
  261. start_idx = cu_seq_lens[i]
  262. end_idx = cu_seq_lens[i + 1]
  263. seq_len = end_idx - start_idx
  264. # Create attention mask.
  265. attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
  266. diagonal=1)
  267. attn_mask = attn_mask * torch.finfo(dtype).min
  268. attn_mask = attn_mask.to(dtype=dtype)
  269. ref_output = ref_masked_attention(
  270. query[start_idx:end_idx],
  271. key[start_idx:end_idx],
  272. value[start_idx:end_idx],
  273. scale,
  274. attn_mask=attn_mask,
  275. )
  276. ref_outputs.append(ref_output)
  277. return torch.cat(ref_outputs, dim=0)
  278. # TODO(woosuk): Add tests for USE_ALIBI=True.
  279. @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
  280. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  281. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  282. @pytest.mark.parametrize("dtype", DTYPES)
  283. @pytest.mark.parametrize("seed", SEEDS)
  284. @pytest.mark.parametrize("device", CUDA_DEVICES)
  285. @torch.inference_mode()
  286. def test_multi_query_kv_attention(
  287. num_seqs: int,
  288. num_heads: Tuple[int, int],
  289. head_size: int,
  290. dtype: torch.dtype,
  291. seed: int,
  292. device: str,
  293. ) -> None:
  294. random.seed(seed)
  295. torch.random.manual_seed(seed)
  296. if torch.cuda.is_available():
  297. torch.cuda.manual_seed(seed)
  298. torch.set_default_device(device)
  299. # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
  300. # As the xformers library is already tested with its own tests, we can use
  301. # a smaller MAX_SEQ_LEN here.
  302. max_len = min(MAX_SEQ_LEN, 4096)
  303. seq_lens = random.sample(range(1, max_len), num_seqs)
  304. num_tokens = sum(seq_lens)
  305. scale = float(1.0 / (head_size**0.5))
  306. num_query_heads, num_kv_heads = num_heads
  307. qkv = torch.empty(num_tokens,
  308. num_query_heads + 2 * num_kv_heads,
  309. head_size,
  310. dtype=dtype)
  311. qkv.uniform_(-scale, scale)
  312. query, key, value = qkv.split(
  313. [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
  314. num_queries_per_kv = num_query_heads // num_kv_heads
  315. if num_queries_per_kv > 1:
  316. # Handle MQA and GQA
  317. key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
  318. value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
  319. attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
  320. output = xops.memory_efficient_attention_forward(
  321. query.unsqueeze(0),
  322. key.unsqueeze(0),
  323. value.unsqueeze(0),
  324. attn_bias=attn_bias,
  325. p=0.0,
  326. scale=scale,
  327. )
  328. output = output.squeeze(0)
  329. cu_seq_lens = [0]
  330. for seq_len in seq_lens:
  331. cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
  332. ref_output = ref_multi_query_kv_attention(
  333. cu_seq_lens,
  334. query,
  335. key,
  336. value,
  337. scale,
  338. dtype,
  339. )
  340. atol = get_default_atol(output) if is_hip() else 1e-3
  341. rtol = get_default_rtol(output) if is_hip() else 1e-5
  342. torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)