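"""Tests for the custom KV-cache ops: copy_blocks, reshape_and_cache,
reshape_and_cache_flash, swap_blocks, and fp8 cache conversion."""
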
import random
from typing import List, Tuple

import pytest
import torch

from aphrodite import _custom_ops as ops
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]
# Arbitrary values for testing.
# Don't make this too large, e.g. [1024, 36000] will OOM.
NUM_BLOCKS = [1024, 10000]
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
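    """Check copy_blocks against a pure-PyTorch reference: every source
    block must be copied into two destination blocks, across all layers of
    both the key and value caches.
    """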
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 kv-cache requires head_size to be a multiple of 16")
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))
    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)
    opcheck(torch.ops._C_cache_ops.copy_blocks,
            (key_caches, value_caches, block_mapping_tensor),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS,
            cond=(head_size == HEAD_SIZES[0]))
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
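    """Check reshape_and_cache against a pure-PyTorch reference: each
    token's key/value must land in the cache block and offset given by the
    slot mapping (fp8 caches are compared after conversion back to fp16).
    """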
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 kv-cache requires head_size to be a multiple of 16")
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel.
    opcheck(torch.ops._C_cache_ops.reshape_and_cache,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)
    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
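    """Check reshape_and_cache_flash against a pure-PyTorch reference,
    using the flash layout where the cache is indexed as
    [block_idx, block_offset, num_heads, head_size].
    """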
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)
    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache = key_caches[0].contiguous()
    value_cache = value_caches[0].contiguous()
    del key_caches
    del value_caches

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache_flash kernel.
    opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
             k_scale, v_scale),
            cond=(head_size == HEAD_SIZES[0]))
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)
    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
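    """Check swap_blocks for cuda->cpu, cuda->cuda, and cpu->cuda copies:
    after the swap, every destination block must equal the corresponding
    source block as it was before the swap.
    """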
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip("fp8 kv-cache is not supported for swaps involving CPU")
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip("fp8 kv-cache requires head_size to be a multiple of 16")
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)
    # Create the KV caches on the source device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the destination device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    do_opcheck = (head_size == HEAD_SIZES[0])
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_key_caches[0], dst_key_caches[0], block_mapping_tensor),
            cond=do_opcheck)
    opcheck(torch.ops._C_cache_ops.swap_blocks,
            (src_value_caches[0], dst_value_caches[0], block_mapping_tensor),
            cond=do_opcheck)
    ops.swap_blocks(src_key_caches[0], dst_key_caches[0],
                    block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0],
                    block_mapping_tensor)

    # Every swapped destination block must match its source block.
    for src, dst in block_mapping:
        torch.testing.assert_close(src_key_caches_clone[src].cpu(),
                                   dst_key_caches[0][dst].cpu())
        torch.testing.assert_close(src_value_caches_clone[src].cpu(),
                                   dst_value_caches[0][dst].cpu())


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
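    """Round-trip a cache through fp8 E4M3 (dtype -> fp8 -> dtype) and
    check that the result matches the original within quantization
    tolerance.
    """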
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Keep values well inside the representable range of fp8 E4M3 so the
    # round trip is only subject to quantization error.
    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # fp8 values are stored as raw bytes in a uint8 tensor.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)