# test_flashinfer.py: FlashInfer paged-KV attention tests (decode and prefill).
from typing import List, Optional, Tuple

import flashinfer
import pytest
import torch
  5. NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
  6. HEAD_SIZES = [128, 256]
  7. BLOCK_SIZES = [16, 32]
  8. DTYPES = [torch.float16, torch.bfloat16]
  9. NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
  10. def ref_paged_attn(
  11. query: torch.Tensor,
  12. key_cache: torch.Tensor,
  13. value_cache: torch.Tensor,
  14. query_lens: List[int],
  15. kv_lens: List[int],
  16. block_tables: torch.Tensor,
  17. scale: float,
  18. sliding_window: Optional[int] = None,
  19. soft_cap: Optional[float] = None,
  20. ) -> torch.Tensor:
  21. num_seqs = len(query_lens)
  22. block_tables = block_tables.cpu().numpy()
  23. _, block_size, num_kv_heads, head_size = key_cache.shape
  24. outputs: List[torch.Tensor] = []
  25. start_idx = 0
  26. for i in range(num_seqs):
  27. query_len = query_lens[i]
  28. kv_len = kv_lens[i]
  29. q = query[start_idx:start_idx + query_len]
  30. q *= scale
  31. num_kv_blocks = (kv_len + block_size - 1) // block_size
  32. block_indices = block_tables[i, :num_kv_blocks]
  33. k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
  34. k = k[:kv_len]
  35. v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
  36. v = v[:kv_len]
  37. if q.shape[1] != k.shape[1]:
  38. k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
  39. v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
  40. attn = torch.einsum("qhd,khd->hqk", q, k).float()
  41. empty_mask = torch.ones(query_len, kv_len)
  42. mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
  43. if sliding_window is not None:
  44. sliding_window_mask = torch.triu(empty_mask,
  45. diagonal=kv_len -
  46. (query_len + sliding_window) +
  47. 1).bool().logical_not()
  48. mask |= sliding_window_mask
  49. if soft_cap is not None:
  50. attn = soft_cap * torch.tanh(attn / soft_cap)
  51. attn.masked_fill_(mask, float("-inf"))
  52. attn = torch.softmax(attn, dim=-1).to(v.dtype)
  53. out = torch.einsum("hqk,khd->qhd", attn, v)
  54. outputs.append(out)
  55. start_idx += query_len
  56. return torch.cat(outputs, dim=0)
  57. @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
  58. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  59. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  60. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  61. @pytest.mark.parametrize("dtype", DTYPES)
  62. @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
  63. @torch.inference_mode
  64. def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
  65. num_heads: Tuple[int,
  66. int], head_size: int,
  67. dtype: torch.dtype, block_size: int,
  68. soft_cap: Optional[float]) -> None:
  69. torch.set_default_device("cuda")
  70. torch.cuda.manual_seed_all(0)
  71. num_seqs = len(kv_lens)
  72. num_query_heads = num_heads[0]
  73. num_kv_heads = num_heads[1]
  74. assert num_query_heads % num_kv_heads == 0
  75. max_kv_len = max(kv_lens)
  76. scale = head_size**-0.5
  77. query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
  78. key_value_cache = torch.randn(NUM_BLOCKS,
  79. 2,
  80. block_size,
  81. num_kv_heads,
  82. head_size,
  83. dtype=dtype)
  84. key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
  85. value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
  86. max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
  87. block_tables = torch.randint(0,
  88. NUM_BLOCKS,
  89. (num_seqs, max_num_blocks_per_seq),
  90. dtype=torch.int32)
  91. kv_indptr = [0]
  92. kv_indices = []
  93. kv_last_page_lens = []
  94. for i in range(num_seqs):
  95. seq_len = kv_lens[i]
  96. assert seq_len > 0
  97. num_blocks = (seq_len + block_size - 1) // block_size
  98. kv_indices.extend(block_tables[i, :num_blocks])
  99. kv_indptr.append(kv_indptr[-1] + num_blocks)
  100. kv_last_page_len = seq_len % block_size
  101. if kv_last_page_len == 0:
  102. kv_last_page_len = block_size
  103. kv_last_page_lens.append(kv_last_page_len)
  104. kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
  105. kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
  106. kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
  107. workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
  108. wrapper = flashinfer.\
  109. BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
  110. use_tensor_cores=(
  111. (num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
  112. )
  113. wrapper.begin_forward(kv_indptr,
  114. kv_indices,
  115. kv_last_page_lens,
  116. num_query_heads,
  117. num_kv_heads,
  118. head_size,
  119. block_size,
  120. "NONE",
  121. data_type=dtype)
  122. output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
  123. ref_output = ref_paged_attn(query=query,
  124. key_cache=key_cache,
  125. value_cache=value_cache,
  126. query_lens=[1] * num_seqs,
  127. kv_lens=kv_lens,
  128. block_tables=block_tables,
  129. scale=scale,
  130. soft_cap=soft_cap)
  131. torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
  132. f"{torch.max(torch.abs(output - ref_output))}"
  133. @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
  134. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  135. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  136. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  137. @pytest.mark.parametrize("dtype", DTYPES)
  138. @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
  139. @torch.inference_mode
  140. def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
  141. num_heads: Tuple[int, int],
  142. head_size: int, dtype: torch.dtype,
  143. block_size: int,
  144. soft_cap: Optional[float]) -> None:
  145. torch.set_default_device("cuda")
  146. torch.cuda.manual_seed_all(0)
  147. num_seqs = len(seq_lens)
  148. query_lens = [x[0] for x in seq_lens]
  149. kv_lens = [x[1] for x in seq_lens]
  150. num_query_heads = num_heads[0]
  151. num_kv_heads = num_heads[1]
  152. assert num_query_heads % num_kv_heads == 0
  153. max_kv_len = max(kv_lens)
  154. scale = head_size**-0.5
  155. query = torch.randn(sum(query_lens),
  156. num_query_heads,
  157. head_size,
  158. dtype=dtype)
  159. key_value_cache = torch.randn(NUM_BLOCKS,
  160. 2,
  161. block_size,
  162. num_kv_heads,
  163. head_size,
  164. dtype=dtype)
  165. key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
  166. value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
  167. # Normalize the scale of the key and value caches to mitigate
  168. # numerical instability.
  169. key_cache /= head_size**0.5
  170. value_cache /= head_size**0.5
  171. max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
  172. block_tables = torch.randint(0,
  173. NUM_BLOCKS,
  174. (num_seqs, max_num_blocks_per_seq),
  175. dtype=torch.int32)
  176. qo_indptr = [0]
  177. kv_indptr = [0]
  178. kv_indices = []
  179. kv_last_page_lens = []
  180. for i in range(num_seqs):
  181. seq_len = kv_lens[i]
  182. assert seq_len > 0
  183. num_blocks = (seq_len + block_size - 1) // block_size
  184. kv_indices.extend(block_tables[i, :num_blocks])
  185. kv_indptr.append(kv_indptr[-1] + num_blocks)
  186. kv_last_page_len = seq_len % block_size
  187. if kv_last_page_len == 0:
  188. kv_last_page_len = block_size
  189. kv_last_page_lens.append(kv_last_page_len)
  190. qo_indptr.append(qo_indptr[-1] + query_lens[i])
  191. qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
  192. kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
  193. kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
  194. kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
  195. workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
  196. wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
  197. workspace_buffer, "NHD")
  198. wrapper.begin_forward(
  199. qo_indptr,
  200. kv_indptr,
  201. kv_indices,
  202. kv_last_page_lens,
  203. num_query_heads,
  204. num_kv_heads,
  205. head_size,
  206. block_size,
  207. )
  208. output = wrapper.forward(
  209. query,
  210. key_value_cache,
  211. logits_soft_cap=soft_cap,
  212. )
  213. ref_output = ref_paged_attn(query=query,
  214. key_cache=key_cache,
  215. value_cache=value_cache,
  216. query_lens=query_lens,
  217. kv_lens=kv_lens,
  218. block_tables=block_tables,
  219. scale=scale,
  220. soft_cap=soft_cap)
  221. torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
  222. f"{torch.max(torch.abs(output - ref_output))}"