// cache_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
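
// Copies whole cache blocks between `src` and `dst` according to
// `block_mapping` (source block number -> destination block number).
// Supports GPU<->GPU (same device), GPU->CPU, and CPU->GPU copies; all copies
// are issued asynchronously on the current CUDA stream.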
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE: This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
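
// Example (host-side sketch with hypothetical tensors): swapping GPU blocks
// 0 and 3 out to CPU blocks 5 and 7 of a cache with the same block shape.
//   std::map<int64_t, int64_t> mapping = {{0, 5}, {3, 7}};
//   swap_blocks(gpu_cache, cpu_cache, mapping);  // async on the current stream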
namespace aphrodite {

// Grid: (num_layers, num_pairs)
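// Copies cache blocks in place within each layer's key and value caches.
// `block_mapping` is a flattened array of (src, dst) block-number pairs;
// thread block (blockIdx.x, blockIdx.y) handles one (layer, pair) combination
// and its threads stride over the `numel_per_block` elements of the block.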
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite
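
// Host entry point: copies blocks within the (GPU-resident) key/value caches
// of every layer. `block_mapping` maps a source block number to the list of
// destination block numbers it should be copied to.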
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }
  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}
namespace aphrodite {
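
// Scatters the key/value vectors of each token into the paged KV cache.
// One thread block handles one token; threads stride over the
// num_heads * head_size elements. The key cache is laid out as
// [num_blocks, num_heads, head_size/x, block_size, x] and the value cache as
// [num_blocks, num_heads, head_size, block_size]. When is_fp8_e5m2_kv_cache
// is set, values are converted to FP8 E5M2 before being stored.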
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace aphrodite

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                      \
  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                             \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                           \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                                    \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                                  \
    slot_mapping.data_ptr<int64_t>(),                                                                    \
    key_stride,                                                                                          \
    value_stride,                                                                                        \
    num_heads,                                                                                           \
    head_size,                                                                                           \
    block_size,                                                                                          \
    x);
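
// Host entry point: writes `key`/`value` for each token into the paged caches
// at the slot given by `slot_mapping`. `kv_cache_dtype` selects the cache
// storage type: "auto" stores in the input dtype, "fp8_e5m2" stores quantized
// uint8 values.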
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}
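
// Example (sketch with hypothetical shapes): caching 16 half-precision tokens
// with 8 heads of size 128 into a cache with block_size 16 and x = 8:
//   key, value:   [16, 8, 128] half tensors
//   key_cache:    [num_blocks, 8, 16, 16, 8] half tensor
//   value_cache:  [num_blocks, 8, 128, 16] half tensor
//   slot_mapping: [16] int64 tensor of slot indices (-1 marks padding tokens)
//   reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, "auto");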
namespace aphrodite {

// Grid: (num_tokens). Block: min(num_heads * head_size, 512) threads.
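// Inverse of reshape_and_cache: reads each token's key/value back out of the
// paged caches (via its slot in `slot_mapping`) into the key/value output
// tensors.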
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  // Number of elements per token (num_heads * head_size), not the number of tokens.
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
  }
}
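
// Variant of the kernel above with a manual 4-way unroll: each thread gathers
// four elements per iteration, computing all indices and loads before the
// stores. Requires num_heads * head_size to be divisible by 4.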
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x)
{
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
  {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace aphrodite
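
// Host entry point: gathers the cached keys/values for every token listed in
// `slot_mapping` into the `key` and `value` output tensors, using the
// unrolled kernel above.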
void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}
namespace aphrodite {
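
// Elementwise conversion between a full-precision cache and an FP8 E5M2
// cache. One thread block handles one cache block (`block_stride` elements);
// the direction of the conversion is determined by the Tout/Tin types.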
template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace aphrodite

#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                      \
  aphrodite::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                             \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                            \
    block_stride);
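
// Host entry point: converts `src_cache` into `dst_cache`. If the source is a
// floating-point cache, the destination is quantized to FP8 E5M2 (uint8);
// otherwise the uint8 source is dequantized into the destination dtype.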
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  } else {
    TORCH_CHECK(false, "Unsupported data type combination for fp8_e5m2 cache conversion");
  }
}