// cache_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif
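
// Copies whole cache blocks between src and dst according to block_mapping
// (source block number -> destination block number). Handles GPU<->GPU
// (same device only), GPU->CPU, and CPU->GPU copies; each block is copied
// with cudaMemcpyAsync on the current CUDA stream.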
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE: This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
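
// Device-side kernel backing copy_blocks() below. The grid is
// (num_layers, num_pairs): blockIdx.x selects the layer, blockIdx.y selects a
// (src, dst) block pair, and the threads of the block stride over the
// numel_per_block elements of both the key and the value cache.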
namespace aphrodite {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite
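
// Host entry point: copies blocks within the per-layer key/value caches on a
// single GPU. block_mapping maps a source block number to one or more
// destination block numbers; the raw cache pointers and the flattened
// (src, dst) pairs are staged into device tensors before the kernel launch.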
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}
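
// Kernel backing reshape_and_cache(): scatters each new token's key/value
// vectors into the paged caches. slot_mapping gives the flat slot per token
// (block_idx = slot / block_size, block_offset = slot % block_size); negative
// slots mark padding tokens and are skipped. With ENABLE_FP8_E5M2 and
// is_fp8_e5m2_kv_cache set, values are quantized to FP8 E5M2 on the way in.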
namespace aphrodite {

template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace aphrodite
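
// Host entry point: writes the keys/values of num_tokens new tokens into the
// paged caches. kv_cache_dtype selects the cache element type: "auto" keeps
// the input dtype, "fp8_e5m2" stores uint8-backed FP8 E5M2. The macro below
// simply instantiates the kernel for a given (input type, cache type) pair.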
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                      \
  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                             \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                           \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                                    \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                                  \
    slot_mapping.data_ptr<int64_t>(),                                                                    \
    key_stride,                                                                                          \
    value_stride,                                                                                        \
    num_heads,                                                                                           \
    head_size,                                                                                           \
    block_size,                                                                                          \
    x);

void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}
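
// The two kernels below invert reshape_and_cache: using slot_mapping, they
// read each token's key/value back out of the paged caches into contiguous
// [num_tokens, num_heads, head_size] output tensors.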
namespace aphrodite {

// Grid: (num_tokens). Each thread block gathers the cached key/value entries
// for one token back into the contiguous output tensors.
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  // Number of elements per token, not the number of tokens.
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
  }
}
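
// Variant of the kernel above that manually unrolls the per-token loop by a
// factor of 4, separating index computation and loads from the final stores.
// It assumes num_heads * head_size is divisible by 4 (checked by the assert).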
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x) {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace aphrodite
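
// Host entry point: gathers the cached keys/values for num_tokens tokens into
// the contiguous output tensors. Only gather_cached_kv_kernel_optimized is
// dispatched here; the plain gather_cached_kv_kernel above is kept for
// reference.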
void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}
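
// Elementwise conversion kernel between a cache tensor and its FP8 E5M2
// (uint8-backed) counterpart: one thread block per cache block, threads
// striding over block_stride elements. Requires ENABLE_FP8_E5M2 at build time.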
namespace aphrodite {

template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace aphrodite
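
// Host entry point: converts src_cache into dst_cache, quantizing to FP8
// E5M2 when the source is float/half/bfloat16 and dequantizing when the
// destination is; the direction is inferred from the tensor dtypes.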
#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                      \
  aphrodite::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                             \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                            \
    block_stride);

void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  // Guard on the source tensor's device, mirroring the other entry points above.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(src_cache));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  }
}
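
// NOTE: For reference, a minimal sketch of how these ops could be exposed to
// Python via the torch extension machinery. This is illustrative only -- the
// actual bindings (module name, op names) live elsewhere in the project:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("swap_blocks", &swap_blocks, "Swap cache blocks between devices");
//     m.def("copy_blocks", &copy_blocks, "Copy blocks within the GPU caches");
//     m.def("reshape_and_cache", &reshape_and_cache, "Write new KV entries into the paged caches");
//     m.def("gather_cached_kv", &gather_cached_kv, "Gather cached KV into contiguous tensors");
//     m.def("convert_fp8_e5m2", &convert_fp8_e5m2, "Convert a cache tensor to/from FP8 E5M2");
//   }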