# test_attention.py

import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from aphrodite._C import ops
from aphrodite.common.utils import get_max_shared_memory_bytes
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 40000  # Arbitrary values for testing
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
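    """Plain-PyTorch reference for scaled dot-product attention.

    Operates on [num_tokens, num_heads, head_size] tensors and applies an
    optional additive mask before the softmax.
    """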
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
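    """Reference single-query (decode) attention over the paged KV cache.

    For each sequence, gathers its keys and values block by block via the
    block table, expands them for MQA/GQA, optionally applies an ALiBi bias,
    and writes the result into `output` in place. The indexing below implies a
    5-D key cache (the head dimension split across its third and last axes)
    and a value cache of shape [num_blocks, num_kv_heads, head_size,
    block_size].
    """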
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        keys = []
        values = []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values,
                                             num_queries_per_kv,
                                             dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(context_len,
                                        device=query.device).int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
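    """Checks the paged attention kernels (v1 and v2) against the reference.

    Builds random queries, block tables, and KV caches (via the
    `kv_cache_factory` fixture), runs the selected kernel, and compares the
    output with `ref_single_query_cached_kv_attention` under a relaxed
    tolerance.
    """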
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=gpu_id)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=gpu_id)

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size, dtype,
                                                seed, gpu_id)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    elif version == "v2":
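        # v2 splits each context into PARTITION_SIZE-token partitions; the
        # extra buffers below hold per-partition partial results (outputs,
        # softmax sums, and running maxima) that the kernel reduces into
        # `output`.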
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE: Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
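    """Reference prefill attention over packed variable-length sequences.

    `cu_seq_lens` holds cumulative sequence lengths; each sequence is sliced
    out of the packed tensors, attended with a causal mask, and the
    per-sequence outputs are concatenated back together.
    """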
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype, device=query.device)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


# TODO: Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
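    """Checks xformers memory-efficient attention against the reference.

    Generates packed variable-length sequences, runs
    `xops.memory_efficient_attention_forward` with a block-diagonal causal
    mask, and compares the result with `ref_multi_query_kv_attention`.
    """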
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device=gpu_id)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
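
    # The block-diagonal causal mask restricts each packed sequence to attend
    # only to its own earlier tokens, mirroring the per-sequence causal masks
    # used in the reference implementation.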
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)