test_flash_attn.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. from typing import List, Optional, Tuple
  2. import pytest
  3. import torch
  4. import aphrodite.attention.backends.flash_attn # noqa: F401
  5. NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
  6. HEAD_SIZES = [128, 256]
  7. BLOCK_SIZES = [16, 32]
  8. DTYPES = [torch.float16, torch.bfloat16]
  9. # one value large enough to test overflow in index calculation.
  10. # one value small enough to test the schema op check
  11. NUM_BLOCKS = [32768, 2048]
  12. def ref_paged_attn(
  13. query: torch.Tensor,
  14. key_cache: torch.Tensor,
  15. value_cache: torch.Tensor,
  16. query_lens: List[int],
  17. kv_lens: List[int],
  18. block_tables: torch.Tensor,
  19. scale: float,
  20. sliding_window: Optional[int] = None,
  21. soft_cap: Optional[float] = None,
  22. ) -> torch.Tensor:
  23. num_seqs = len(query_lens)
  24. block_tables = block_tables.cpu().numpy()
  25. _, block_size, num_kv_heads, head_size = key_cache.shape
  26. outputs: List[torch.Tensor] = []
  27. start_idx = 0
  28. for i in range(num_seqs):
  29. query_len = query_lens[i]
  30. kv_len = kv_lens[i]
  31. q = query[start_idx:start_idx + query_len]
  32. q *= scale
  33. num_kv_blocks = (kv_len + block_size - 1) // block_size
  34. block_indices = block_tables[i, :num_kv_blocks]
  35. k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
  36. k = k[:kv_len]
  37. v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
  38. v = v[:kv_len]
  39. if q.shape[1] != k.shape[1]:
  40. k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
  41. v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
  42. attn = torch.einsum("qhd,khd->hqk", q, k).float()
  43. empty_mask = torch.ones(query_len, kv_len)
  44. mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
  45. if sliding_window is not None:
  46. sliding_window_mask = torch.triu(empty_mask,
  47. diagonal=kv_len -
  48. (query_len + sliding_window) +
  49. 1).bool().logical_not()
  50. mask |= sliding_window_mask
  51. if soft_cap is not None:
  52. attn = soft_cap * torch.tanh(attn / soft_cap)
  53. attn.masked_fill_(mask, float("-inf"))
  54. attn = torch.softmax(attn, dim=-1).to(v.dtype)
  55. out = torch.einsum("hqk,khd->qhd", attn, v)
  56. outputs.append(out)
  57. start_idx += query_len
  58. return torch.cat(outputs, dim=0)
  59. @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
  60. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  61. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  62. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  63. @pytest.mark.parametrize("dtype", DTYPES)
  64. @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
  65. @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
  66. @torch.inference_mode()
  67. def test_flash_attn_with_paged_kv(
  68. kv_lens: List[int],
  69. num_heads: Tuple[int, int],
  70. head_size: int,
  71. dtype: torch.dtype,
  72. block_size: int,
  73. soft_cap: Optional[float],
  74. num_blocks: int,
  75. ) -> None:
  76. torch.set_default_device("cuda")
  77. torch.cuda.manual_seed_all(0)
  78. num_seqs = len(kv_lens)
  79. num_query_heads = num_heads[0]
  80. num_kv_heads = num_heads[1]
  81. assert num_query_heads % num_kv_heads == 0
  82. max_kv_len = max(kv_lens)
  83. scale = head_size**-0.5
  84. query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
  85. key_cache = torch.randn(num_blocks,
  86. block_size,
  87. num_kv_heads,
  88. head_size,
  89. dtype=dtype)
  90. value_cache = torch.randn_like(key_cache)
  91. kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
  92. max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
  93. block_tables = torch.randint(0,
  94. num_blocks,
  95. (num_seqs, max_num_blocks_per_seq),
  96. dtype=torch.int32)
  97. output = torch.ops.aphrodite.flash_attn_with_kvcache(
  98. decode_query=query.unsqueeze(1),
  99. key_cache=key_cache,
  100. value_cache=value_cache,
  101. softmax_scale=scale,
  102. causal=True,
  103. block_table=block_tables,
  104. cache_seqlens=kv_lens_tensor,
  105. softcap=soft_cap if soft_cap is not None else 0,
  106. ).squeeze(1)
  107. if num_blocks <= 2048:
  108. test_utils = ["test_faketensor", "test_schema"]
  109. else:
  110. test_utils = ["test_faketensor"]
  111. torch.library.opcheck(torch.ops.aphrodite.flash_attn_with_kvcache,
  112. args=tuple(),
  113. kwargs=dict(
  114. decode_query=query.unsqueeze(1),
  115. key_cache=key_cache,
  116. value_cache=value_cache,
  117. softmax_scale=scale,
  118. causal=True,
  119. block_table=block_tables,
  120. cache_seqlens=kv_lens_tensor,
  121. softcap=soft_cap if soft_cap is not None else 0,
  122. ),
  123. test_utils=test_utils)
  124. ref_output = ref_paged_attn(
  125. query=query,
  126. key_cache=key_cache,
  127. value_cache=value_cache,
  128. query_lens=[1] * num_seqs,
  129. kv_lens=kv_lens,
  130. block_tables=block_tables,
  131. scale=scale,
  132. soft_cap=soft_cap,
  133. )
  134. torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
  135. f"{torch.max(torch.abs(output - ref_output))}"
  136. @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
  137. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  138. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  139. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  140. @pytest.mark.parametrize("sliding_window", [None])
  141. @pytest.mark.parametrize("dtype", DTYPES)
  142. @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
  143. @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
  144. @torch.inference_mode()
  145. def test_varlen_with_paged_kv(
  146. seq_lens: List[Tuple[int, int]],
  147. num_heads: Tuple[int, int],
  148. head_size: int,
  149. sliding_window: Optional[int],
  150. dtype: torch.dtype,
  151. block_size: int,
  152. soft_cap: Optional[float],
  153. num_blocks: int,
  154. ) -> None:
  155. torch.set_default_device("cuda")
  156. torch.cuda.manual_seed_all(0)
  157. num_seqs = len(seq_lens)
  158. query_lens = [x[0] for x in seq_lens]
  159. kv_lens = [x[1] for x in seq_lens]
  160. num_query_heads = num_heads[0]
  161. num_kv_heads = num_heads[1]
  162. assert num_query_heads % num_kv_heads == 0
  163. max_query_len = max(query_lens)
  164. max_kv_len = max(kv_lens)
  165. window_size = ((sliding_window,
  166. sliding_window) if sliding_window is not None else
  167. (-1, -1))
  168. scale = head_size**-0.5
  169. query = torch.randn(sum(query_lens),
  170. num_query_heads,
  171. head_size,
  172. dtype=dtype)
  173. key_cache = torch.randn(num_blocks,
  174. block_size,
  175. num_kv_heads,
  176. head_size,
  177. dtype=dtype)
  178. value_cache = torch.randn_like(key_cache)
  179. cu_query_lens = torch.tensor([0] + query_lens,
  180. dtype=torch.int32).cumsum(dim=0,
  181. dtype=torch.int32)
  182. cu_kv_lens = torch.tensor([0] + kv_lens,
  183. dtype=torch.int32).cumsum(dim=0,
  184. dtype=torch.int32)
  185. max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
  186. block_tables = torch.randint(0,
  187. num_blocks,
  188. (num_seqs, max_num_blocks_per_seq),
  189. dtype=torch.int32)
  190. output = torch.ops.aphrodite.flash_attn_varlen_func(
  191. q=query,
  192. k=key_cache,
  193. v=value_cache,
  194. cu_seqlens_q=cu_query_lens,
  195. cu_seqlens_k=cu_kv_lens,
  196. max_seqlen_q=max_query_len,
  197. max_seqlen_k=max_kv_len,
  198. softmax_scale=scale,
  199. causal=True,
  200. window_size=window_size,
  201. block_table=block_tables,
  202. softcap=soft_cap if soft_cap is not None else 0,
  203. )
  204. if num_blocks <= 2048:
  205. test_utils = ["test_faketensor", "test_schema"]
  206. else:
  207. test_utils = ["test_faketensor"]
  208. torch.library.opcheck(torch.ops.aphrodite.flash_attn_varlen_func,
  209. args=tuple(),
  210. kwargs=dict(
  211. q=query,
  212. k=key_cache,
  213. v=value_cache,
  214. cu_seqlens_q=cu_query_lens,
  215. cu_seqlens_k=cu_kv_lens,
  216. max_seqlen_q=max_query_len,
  217. max_seqlen_k=max_kv_len,
  218. softmax_scale=scale,
  219. causal=True,
  220. window_size=window_size,
  221. block_table=block_tables,
  222. softcap=soft_cap if soft_cap is not None else 0,
  223. ),
  224. test_utils=test_utils)
  225. ref_output = ref_paged_attn(
  226. query=query,
  227. key_cache=key_cache,
  228. value_cache=value_cache,
  229. query_lens=query_lens,
  230. kv_lens=kv_lens,
  231. block_tables=block_tables,
  232. scale=scale,
  233. sliding_window=sliding_window,
  234. soft_cap=soft_cap,
  235. )
  236. torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
  237. f"{torch.max(torch.abs(output - ref_output))}"