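"""Tests for the KV cache kernels in aphrodite._C.ops.

Each test (copy_blocks, reshape_and_cache, swap_blocks, convert_fp8) runs
the custom kernel and checks its output against a pure-PyTorch reference
implementation.
"""
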
import random
from typing import Tuple

import pytest
import torch

from aphrodite._C import ops
from aphrodite.common.utils import is_hip

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
# Arbitrary values for testing.
# Don't make this too large, e.g. [1024, 36000] will OOM.
NUM_BLOCKS = [1024, 10000]
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# Use up to two devices when more than one GPU is available.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
KV_CACHE_DTYPE = ["auto", "fp8"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
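    """Compare the copy_blocks kernel against a pure-PyTorch reference.

    Each source block is mapped to two destination blocks; after the kernel
    runs, every destination block in every layer's key and value cache must
    equal its source block.
    """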
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Generate random block mappings where each source block is mapped to two
    # destination blocks.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = {}
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping[src] = [dst1, dst2]

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, kv_cache_dtype,
                                                dtype, seed, device)

    # Clone the KV caches.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy_blocks kernel.
    ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Run the reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
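    """Compare the reshape_and_cache kernel against a pure-PyTorch reference.

    Scatters num_tokens key/value vectors into randomly chosen cache slots
    with the kernel, replays the scatter in PyTorch, and compares. For fp8
    caches, both sides are converted back to fp16 before the comparison.
    """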
    if not is_hip() and kv_cache_dtype == "fp8":
        pytest.skip()  # This test is not tuned for e5m2 CUDA precision.
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)

    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                num_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(key_cache, cloned_key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(value_cache, cloned_value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Use the default kv_scale.
    kv_scale = 1.0

    # Call the reshape_and_cache kernel.
    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                          kv_cache_dtype, kv_scale)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(key_cache, result_key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(value_cache, result_value_cache)

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        assert torch.allclose(result_key_cache,
                              cloned_key_cache,
                              atol=0.001,
                              rtol=0.1)
        assert torch.allclose(result_value_cache,
                              cloned_value_cache,
                              atol=0.001,
                              rtol=0.1)
    else:
        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
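    """Compare the swap_blocks kernel against a pure-PyTorch reference.

    Moves the mapped blocks from a source cache to a destination cache,
    possibly across devices (CUDA <-> CPU), and verifies that each mapped
    destination block matches the pre-swap contents of its source block.
    """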
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    if not is_hip() and kv_cache_dtype == "fp8":
        pytest.skip()  # This test is not tuned for e5m2 CUDA precision.
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    src_device = device if direction[0] == "cuda" else "cpu"
    dst_device = device if direction[1] == "cuda" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # On the same device, the mapping must not overlap.
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)
    block_mapping = dict(zip(src_blocks, dst_blocks))

    # Create the KV caches on the first device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, src_device)

    # Create the KV caches on the second device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
        seed, dst_device)

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel.
    ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0], block_mapping)

    for src, dst in block_mapping.items():
        assert torch.allclose(src_key_caches_clone[src].cpu(),
                              dst_key_caches[0][dst].cpu())
        assert torch.allclose(src_value_caches_clone[src].cpu(),
                              dst_value_caches[0][dst].cpu())


@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
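    """Round-trip a random cache through convert_fp8 and back.

    Values drawn from [-224, 224] must be recovered within a loose tolerance
    after the dtype -> fp8 (uint8 storage) -> dtype round trip.
    """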
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache, cache_fp8)

    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(cache_fp8, converted_cache)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)