test_flashinfer.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. from typing import List, Optional, Tuple
  2. import flashinfer
  3. import pytest
  4. import torch
  5. NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
  6. HEAD_SIZES = [128, 256]
  7. BLOCK_SIZES = [16, 32]
  8. DTYPES = [torch.float16, torch.bfloat16]
  9. NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
  10. def ref_paged_attn(
  11. query: torch.Tensor,
  12. key_cache: torch.Tensor,
  13. value_cache: torch.Tensor,
  14. query_lens: List[int],
  15. kv_lens: List[int],
  16. block_tables: torch.Tensor,
  17. scale: float,
  18. sliding_window: Optional[int] = None,
  19. soft_cap: Optional[float] = None,
  20. ) -> torch.Tensor:
  21. num_seqs = len(query_lens)
  22. block_tables = block_tables.cpu().numpy()
  23. _, block_size, num_kv_heads, head_size = key_cache.shape
  24. outputs: List[torch.Tensor] = []
  25. start_idx = 0
  26. for i in range(num_seqs):
  27. query_len = query_lens[i]
  28. kv_len = kv_lens[i]
  29. q = query[start_idx:start_idx + query_len]
  30. q *= scale
  31. num_kv_blocks = (kv_len + block_size - 1) // block_size
  32. block_indices = block_tables[i, :num_kv_blocks]
  33. k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
  34. k = k[:kv_len]
  35. v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
  36. v = v[:kv_len]
  37. if q.shape[1] != k.shape[1]:
  38. k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
  39. v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
  40. attn = torch.einsum("qhd,khd->hqk", q, k).float()
  41. empty_mask = torch.ones(query_len, kv_len)
  42. mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
  43. if sliding_window is not None:
  44. sliding_window_mask = torch.triu(empty_mask,
  45. diagonal=kv_len -
  46. (query_len + sliding_window) +
  47. 1).bool().logical_not()
  48. mask |= sliding_window_mask
  49. if soft_cap is not None:
  50. attn = soft_cap * torch.tanh(attn / soft_cap)
  51. attn.masked_fill_(mask, float("-inf"))
  52. attn = torch.softmax(attn, dim=-1).to(v.dtype)
  53. out = torch.einsum("hqk,khd->qhd", attn, v)
  54. outputs.append(out)
  55. start_idx += query_len
  56. return torch.cat(outputs, dim=0)
  57. @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
  58. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  59. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  60. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  61. @pytest.mark.parametrize("dtype", DTYPES)
  62. @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
  63. @torch.inference_mode
  64. def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
  65. num_heads: Tuple[int,
  66. int], head_size: int,
  67. dtype: torch.dtype, block_size: int,
  68. soft_cap: Optional[float]) -> None:
  69. torch.set_default_device("cuda")
  70. torch.cuda.manual_seed_all(0)
  71. num_seqs = len(kv_lens)
  72. num_query_heads = num_heads[0]
  73. num_kv_heads = num_heads[1]
  74. assert num_query_heads % num_kv_heads == 0
  75. max_kv_len = max(kv_lens)
  76. scale = head_size**-0.5
  77. query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
  78. key_value_cache = torch.randn(NUM_BLOCKS,
  79. 2,
  80. block_size,
  81. num_kv_heads,
  82. head_size,
  83. dtype=dtype)
  84. key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
  85. value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
  86. max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
  87. block_tables = torch.randint(0,
  88. NUM_BLOCKS,
  89. (num_seqs, max_num_blocks_per_seq),
  90. dtype=torch.int32)
  91. kv_indptr = [0]
  92. kv_indices = []
  93. kv_last_page_lens = []
  94. for i in range(num_seqs):
  95. seq_len = kv_lens[i]
  96. assert seq_len > 0
  97. num_blocks = (seq_len + block_size - 1) // block_size
  98. kv_indices.extend(block_tables[i, :num_blocks])
  99. kv_indptr.append(kv_indptr[-1] + num_blocks)
  100. kv_last_page_len = seq_len % block_size
  101. if kv_last_page_len == 0:
  102. kv_last_page_len = block_size
  103. kv_last_page_lens.append(kv_last_page_len)
  104. kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
  105. kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
  106. kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
  107. workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
  108. wrapper = flashinfer.\
  109. BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
  110. wrapper.begin_forward(kv_indptr,
  111. kv_indices,
  112. kv_last_page_lens,
  113. num_query_heads,
  114. num_kv_heads,
  115. head_size,
  116. block_size,
  117. "NONE",
  118. data_type=dtype)
  119. output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
  120. ref_output = ref_paged_attn(query=query,
  121. key_cache=key_cache,
  122. value_cache=value_cache,
  123. query_lens=[1] * num_seqs,
  124. kv_lens=kv_lens,
  125. block_tables=block_tables,
  126. scale=scale,
  127. soft_cap=soft_cap)
  128. torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
  129. f"{torch.max(torch.abs(output - ref_output))}"
  130. @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
  131. @pytest.mark.parametrize("num_heads", NUM_HEADS)
  132. @pytest.mark.parametrize("head_size", HEAD_SIZES)
  133. @pytest.mark.parametrize("block_size", BLOCK_SIZES)
  134. @pytest.mark.parametrize("dtype", DTYPES)
  135. @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
  136. @torch.inference_mode
  137. def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
  138. num_heads: Tuple[int, int],
  139. head_size: int, dtype: torch.dtype,
  140. block_size: int,
  141. soft_cap: Optional[float]) -> None:
  142. torch.set_default_device("cuda")
  143. torch.cuda.manual_seed_all(0)
  144. num_seqs = len(seq_lens)
  145. query_lens = [x[0] for x in seq_lens]
  146. kv_lens = [x[1] for x in seq_lens]
  147. num_query_heads = num_heads[0]
  148. num_kv_heads = num_heads[1]
  149. assert num_query_heads % num_kv_heads == 0
  150. max_kv_len = max(kv_lens)
  151. scale = head_size**-0.5
  152. query = torch.randn(sum(query_lens),
  153. num_query_heads,
  154. head_size,
  155. dtype=dtype)
  156. key_value_cache = torch.randn(NUM_BLOCKS,
  157. 2,
  158. block_size,
  159. num_kv_heads,
  160. head_size,
  161. dtype=dtype)
  162. key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
  163. value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
  164. # Normalize the scale of the key and value caches to mitigate
  165. # numerical instability.
  166. key_cache /= head_size**0.5
  167. value_cache /= head_size**0.5
  168. max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
  169. block_tables = torch.randint(0,
  170. NUM_BLOCKS,
  171. (num_seqs, max_num_blocks_per_seq),
  172. dtype=torch.int32)
  173. qo_indptr = [0]
  174. kv_indptr = [0]
  175. kv_indices = []
  176. kv_last_page_lens = []
  177. for i in range(num_seqs):
  178. seq_len = kv_lens[i]
  179. assert seq_len > 0
  180. num_blocks = (seq_len + block_size - 1) // block_size
  181. kv_indices.extend(block_tables[i, :num_blocks])
  182. kv_indptr.append(kv_indptr[-1] + num_blocks)
  183. kv_last_page_len = seq_len % block_size
  184. if kv_last_page_len == 0:
  185. kv_last_page_len = block_size
  186. kv_last_page_lens.append(kv_last_page_len)
  187. qo_indptr.append(qo_indptr[-1] + query_lens[i])
  188. qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
  189. kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
  190. kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
  191. kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
  192. workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
  193. wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
  194. workspace_buffer, "NHD")
  195. wrapper.begin_forward(
  196. qo_indptr,
  197. kv_indptr,
  198. kv_indices,
  199. kv_last_page_lens,
  200. num_query_heads,
  201. num_kv_heads,
  202. head_size,
  203. block_size,
  204. )
  205. output = wrapper.forward(
  206. query,
  207. key_value_cache,
  208. logits_soft_cap=soft_cap,
  209. )
  210. ref_output = ref_paged_attn(query=query,
  211. key_cache=key_cache,
  212. value_cache=value_cache,
  213. query_lens=query_lens,
  214. kv_lens=kv_lens,
  215. block_tables=block_tables,
  216. scale=scale,
  217. soft_cap=soft_cap)
  218. torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
  219. f"{torch.max(torch.abs(output - ref_output))}"