import random
from typing import List, Optional, Tuple

import pytest
import torch
from allclose_default import get_default_atol, get_default_rtol
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from aphrodite._C import ops
from aphrodite.common.utils import get_max_shared_memory_bytes, is_hip

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
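# For example, on a GPU with 96 KB of shared memory per thread block this
# works out to 96 * 1024 // 4 - 512 = 24064 tokens; the exact value depends
# on the device.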
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
          ] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 256
              ] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
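
# NOTE: The paged KV cache layout assumed by these tests (see the fp8
# dequantization path and the reference implementation below) is:
#   key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
#                with x = 16 // element_size(dtype)
#   value_cache: [num_blocks, num_kv_heads, head_size, block_size]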


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out
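
# Shape walkthrough for ref_masked_attention above (q = query len, k = key
# len, h = num query heads, d = head size):
#   einsum("qhd,khd->hqk"): query (q, h, d) x key (k, h, d) -> weights (h, q, k)
#   softmax is taken over the last (key) dimension,
#   einsum("hqk,khd->qhd"): weights (h, q, k) x value (k, h, d) -> out (q, h, d)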


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        keys = []
        values = []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(context_len).int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)
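

# Example of the paged lookup in the reference above: with block_size=16,
# token position j=37 lives in physical block block_table[37 // 16], i.e.
# block_table[2], at offset 37 % 16 == 5 within that block.
#
# test_paged_attention builds random queries, block tables, and KV caches,
# runs the paged_attention_v1/v2 kernel, and compares its output against
# ref_single_query_cached_kv_attention.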
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    kv_scale = 1.0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
    elif version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(key_cache, dequantized_key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(value_cache, dequantized_value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE: Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output
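

# The causal mask built above, for seq_len = 3 and m = torch.finfo(dtype).min:
#   [[0, m, m],
#    [0, 0, m],
#    [0, 0, 0]]
# i.e. each query position attends only to itself and to earlier positions.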


# TODO: Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
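    # e.g. seq_lens = [3, 5, 2] -> cu_seq_lens = [0, 3, 8, 10]; sequence i
    # occupies token rows cu_seq_lens[i]:cu_seq_lens[i + 1] of query/key/value.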
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)