test_cache.py

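"""Tests for the paged KV-cache custom ops in aphrodite._custom_ops.

Covers copy_blocks, reshape_and_cache, reshape_and_cache_flash, swap_blocks,
and the fp8 round trip through convert_fp8. The kv_cache_factory and
kv_cache_factory_flashinfer arguments are pytest fixtures, presumably
provided by the test suite's conftest.py.
"""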

import random
from typing import List, Tuple

import pytest
import torch

from aphrodite import _custom_ops as ops

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing.
# Don't make this too large; e.g. [1024, 36000] will OOM.
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
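    """Check ops.copy_blocks against a pure-PyTorch reference.

    Each source block is mapped to two destination blocks; the kernel's
    result is compared with an explicit per-block copy on cloned caches.
    """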
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device=device).view(-1, 2)
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
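    """Check ops.reshape_and_cache against a pure-PyTorch reference.

    Random key/value tokens are scattered into the paged KV cache at random
    slot positions by the kernel; the same writes are then replayed on cloned
    caches and compared. For fp8 caches, results are converted back to
    float16 and compared with relaxed tolerances.
    """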
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, k_scale, v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
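    """Check ops.reshape_and_cache_flash against a pure-PyTorch reference.

    Same idea as test_reshape_and_cache, but for the cache layout produced by
    the kv_cache_factory_flashinfer fixture, where the block offset indexes
    the second cache dimension.
    """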
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache = key_caches[0].contiguous()
    value_cache = value_caches[0].contiguous()
    del key_caches
    del value_caches

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                slot_mapping, kv_cache_dtype, k_scale,
                                v_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(result_key_cache,
                                   cloned_key_cache,
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(result_value_cache,
                                   cloned_value_cache,
                                   atol=0.001,
                                   rtol=0.1)
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
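    """Check ops.swap_blocks for the directions in COPYING_DIRECTION.

    Blocks are swapped from a source cache to a destination cache according
    to a block mapping, then each swapped block is compared against a clone
    of the source cache.
    """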
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else 'cpu'
    dst_device = device if direction[1] == "cuda" else 'cpu'

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, the mapping must not overlap.
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(block_mapping,
                                        dtype=torch.int64,
                                        device="cpu").view(-1, 2)

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dist_key_caches, dist_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
                    block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
                    block_mapping_tensor)

    for src, dst in block_mapping:
        torch.testing.assert_close(src_key_caches_clone[src].cpu(),
                                   dist_key_caches[0][dst].cpu())
        torch.testing.assert_close(src_value_caches_clone[src].cpu(),
                                   dist_value_caches[0][dst].cpu())


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
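    """Round-trip a cache through ops.convert_fp8 and back.

    Values are drawn from [-224, 224], which lies inside the fp8 E4M3
    representable range, so the round trip only loses precision; the result
    is compared with relaxed tolerances.
    """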
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
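

# To run only this file (assuming pytest and a CUDA-enabled build of the
# aphrodite custom ops are available):
#   pytest test_cache.py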