1
0

cache_kernels.cu 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. #include <torch/extension.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include "cuda_compat.h"
  5. #include "dispatch_utils.h"
  6. #ifdef USE_ROCM
  7. #include "quantization/fp8/amd/quant_utils.cuh"
  8. #else
  9. #include "quantization/fp8/nvidia/quant_utils.cuh"
  10. #endif
  11. #include <algorithm>
  12. #include <cassert>
  13. #include <map>
  14. #include <vector>
  15. #ifdef USE_ROCM
  16. #include <hip/hip_bf16.h>
  17. typedef __hip_bfloat16 __nv_bfloat16;
  18. #endif
  // Copies whole cache blocks from `src` to `dst` as directed by
  // `block_mapping`, whose row i holds (src_block_number, dst_block_number).
  // A "block" is one slice along dim 0 of `src`; its byte size is taken from
  // src[0]. Supported directions are GPU->GPU (same device), GPU->CPU and
  // CPU->GPU. Copies are enqueued with cudaMemcpyAsync on the current CUDA
  // stream of the GPU-side device. Any failing cudaMemcpyAsync raises.
  void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                   const torch::Tensor& block_mapping) {
    torch::Device src_device = src.device();
    torch::Device dst_device = dst.device();
    cudaMemcpyKind memcpy_type;
    if (src_device.is_cuda() && dst_device.is_cuda()) {
      TORCH_CHECK(src_device.index() == dst_device.index(),
                  "src and dst must be on the same GPU");
      memcpy_type = cudaMemcpyDeviceToDevice;
    } else if (src_device.is_cuda() && dst_device.is_cpu()) {
      memcpy_type = cudaMemcpyDeviceToHost;
    } else if (src_device.is_cpu() && dst_device.is_cuda()) {
      memcpy_type = cudaMemcpyHostToDevice;
    } else {
      TORCH_CHECK(false, "Invalid device combination");
    }
    // NOTE(youkaichao): `block_mapping` must be a CPU tensor; the host loop
    // below reads it element by element, which would otherwise force a
    // GPU-CPU synchronization per access.
    TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
    const int64_t num_blocks = block_mapping.size(0);
    if (num_blocks == 0) {
      // Nothing to copy; also avoids indexing src[0] on an empty cache.
      return;
    }
    // Read the mapping through a TensorAccessor rather than per-element
    // `.item()` calls, which each build temporary tensors. `.to(kInt64)` is a
    // no-op for int64 input and keeps accepting other integer dtypes, as the
    // former `.item<int64_t>()` conversion did.
    const torch::Tensor mapping = block_mapping.to(torch::kInt64);
    const auto mapping_acc = mapping.accessor<int64_t, 2>();

    char* src_ptr = static_cast<char*>(src.data_ptr());
    char* dst_ptr = static_cast<char*>(dst.data_ptr());
    const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
    const at::cuda::OptionalCUDAGuard device_guard(
        src_device.is_cuda() ? src_device : dst_device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // NOTE(woosuk): This issues one memcpy per block and can be slow if the
    // number of blocks is large.
    for (int64_t i = 0; i < num_blocks; ++i) {
      const int64_t src_offset = mapping_acc[i][0] * block_size_in_bytes;
      const int64_t dst_offset = mapping_acc[i][1] * block_size_in_bytes;
      const cudaError_t err =
          cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                          block_size_in_bytes, memcpy_type, stream);
      TORCH_CHECK(err == cudaSuccess, "cudaMemcpyAsync failed in swap_blocks: ",
                  cudaGetErrorString(err));
    }
  }
  namespace aphrodite {
  // Copies one cache block onto another inside both the key and the value
  // cache of a single layer.
  // Launch layout: blockIdx.x selects the layer, blockIdx.y selects the
  // (src, dst) pair; the threads of the CUDA block step through the
  // `numel_per_block` elements of that cache block in strides of blockDim.x.
  //   key_cache_ptrs / value_cache_ptrs: per-layer base addresses, stored as
  //     int64 values and reinterpreted as scalar_t*.
  //   block_mapping: flattened pairs; entry 2*p is the source block number and
  //     entry 2*p+1 the destination block number.
  template <typename scalar_t>
  __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
                                     int64_t* value_cache_ptrs,
                                     const int64_t* __restrict__ block_mapping,
                                     const int numel_per_block) {
    const int layer = blockIdx.x;
    const int pair = blockIdx.y;
    scalar_t* const keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
    scalar_t* const values =
        reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);

    const int64_t* const entry = block_mapping + 2 * pair;
    const int64_t src_base = entry[0] * numel_per_block;
    const int64_t dst_base = entry[1] * numel_per_block;

    // Key cache first, then value cache, each with the same thread striding.
    for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x) {
      keys[dst_base + e] = keys[src_base + e];
    }
    for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x) {
      values[dst_base + e] = values[src_base + e];
    }
  }
  }  // namespace aphrodite
  // Copies cache blocks in place on the GPU, for every layer's key and value
  // cache, according to `block_mapping`.
  //   key_caches / value_caches: one tensor per layer. Only key_caches[0] is
  //     checked to be on CUDA; the others are assumed to share its device and
  //     per-block element count.
  //   block_mapping: (num_pairs, 2) int64 tensor of (src, dst) block numbers.
  //     The kernel dereferences its data pointer directly, so it must live in
  //     device-accessible memory.
  void copy_blocks(std::vector<torch::Tensor>& key_caches,
                   std::vector<torch::Tensor>& value_caches,
                   const torch::Tensor& block_mapping) {
    TORCH_CHECK(key_caches.size() == value_caches.size(),
                "key_caches and value_caches must have the same number of layers");
    const int num_layers = static_cast<int>(key_caches.size());
    if (num_layers == 0) {
      return;
    }
    torch::Device cache_device = key_caches[0].device();
    TORCH_CHECK(cache_device.is_cuda());

    // block_mapping is a 2D tensor with shape (num_pairs, 2).
    const int num_pairs = block_mapping.size(0);
    if (num_pairs == 0) {
      // Nothing to copy. Launching would also use gridDim.y == 0, an invalid
      // configuration whose error would go unnoticed.
      return;
    }
    const int numel_per_block = key_caches[0][0].numel();
    if (numel_per_block == 0) {
      // A zero-thread block is likewise an invalid launch configuration.
      return;
    }

    const at::cuda::OptionalCUDAGuard device_guard(cache_device);
    // One grid row per pair: reject counts the device cannot launch instead of
    // letting the launch fail silently.
    TORCH_CHECK(
        num_pairs <= at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
        "copy_blocks: number of block pairs (", num_pairs,
        ") exceeds the device's maximum grid y-dimension");

    // Gather the per-layer base addresses (standard containers instead of
    // variable-length arrays, which are not valid C++).
    std::vector<int64_t> key_cache_ptrs(num_layers);
    std::vector<int64_t> value_cache_ptrs(num_layers);
    for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
      key_cache_ptrs[layer_idx] =
          reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
      value_cache_ptrs[layer_idx] =
          reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
    }
    // Move the pointer tables to the GPU.
    // NOTE: This synchronizes the CPU and GPU, so the host vectors stay alive
    // for the whole copy.
    torch::Tensor key_cache_ptrs_tensor =
        torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
            .to(cache_device);
    torch::Tensor value_cache_ptrs_tensor =
        torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
            .to(cache_device);

    // Launch the kernel: one CUDA block per (layer, pair).
    dim3 grid(num_layers, num_pairs);
    dim3 block(std::min(1024, numel_per_block));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
        key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
          aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
              key_cache_ptrs_tensor.data_ptr<int64_t>(),
              value_cache_ptrs_tensor.data_ptr<int64_t>(),
              block_mapping.data_ptr<int64_t>(), numel_per_block);
        }));
  }
  namespace aphrodite {
  // Writes one token's key/value vectors into the paged cache.
  // Launch layout: one CUDA block per token (blockIdx.x); threads step through
  // the num_heads * head_size elements of the token in strides of blockDim.x.
  // Layouts implied by the index math below:
  //   key/value:   token row starts at token * {key,value}_stride, with the
  //                head elements contiguous after it.
  //   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
  //   value_cache: [num_blocks, num_heads, head_size, block_size]
  // A negative slot_mapping entry marks a padding token, which is skipped.
  // With kv_dt == kAuto values are stored as-is; otherwise they go through
  // fp8::scaled_convert with kv_scale.
  template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
  __global__ void reshape_and_cache_kernel(
      const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
      const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
      cache_t* __restrict__ key_cache,     // [num_blocks, num_heads,
                                           //  head_size/x, block_size, x]
      cache_t* __restrict__ value_cache,   // [num_blocks, num_heads,
                                           //  head_size, block_size]
      const int64_t* __restrict__ slot_mapping,  // [num_tokens]
      const int key_stride, const int value_stride, const int num_heads,
      const int head_size, const int block_size, const int x,
      const float kv_scale) {
    const int64_t token = blockIdx.x;
    const int64_t slot = slot_mapping[token];
    if (slot < 0) {
      return;  // padding token
    }

    const int64_t blk = slot / block_size;
    const int64_t pos_in_blk = slot % block_size;
    const int head_chunks = head_size / x;
    // Start of the destination block inside each cache.
    const int64_t key_blk_base = blk * num_heads * head_chunks * block_size * x;
    const int64_t value_blk_base = blk * num_heads * head_size * block_size;

    const int total = num_heads * head_size;
    for (int e = threadIdx.x; e < total; e += blockDim.x) {
      const int h = e / head_size;  // head index
      const int d = e % head_size;  // element within the head
      const int64_t key_dst = key_blk_base + h * head_chunks * block_size * x +
                              (d / x) * block_size * x + pos_in_blk * x +
                              d % x;
      const int64_t value_dst = value_blk_base + h * head_size * block_size +
                                d * block_size + pos_in_blk;
      const scalar_t k = key[token * key_stride + e];
      const scalar_t v = value[token * value_stride + e];
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        key_cache[key_dst] = k;
        value_cache[value_dst] = v;
      } else {
        key_cache[key_dst] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k, kv_scale);
        value_cache[value_dst] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v, kv_scale);
      }
    }
  }

  // Flash-attention cache variant: both caches use the per-block layout
  // [block_size, num_heads, head_size], and consecutive cache blocks are
  // `block_stride` elements apart. One CUDA block per token; negative slots
  // (padding) are skipped. Values are copied without conversion.
  template <typename scalar_t>
  __global__ void reshape_and_cache_flash_kernel(
      const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
      const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
      scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
                                           //  head_size]
      scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
                                           //  head_size]
      const int64_t* __restrict__ slot_mapping,  // [num_tokens]
      const int block_stride, const int key_stride, const int value_stride,
      const int num_heads, const int head_size, const int block_size) {
    const int64_t token = blockIdx.x;
    const int64_t slot = slot_mapping[token];
    if (slot < 0) {
      return;  // padding token
    }

    // Every head element of a token lands in one contiguous row of the cache,
    // so the flat element index doubles as the offset inside that row.
    const int row = num_heads * head_size;
    const int64_t dst_base =
        (slot / block_size) * block_stride + (slot % block_size) * row;
    const scalar_t* key_row = key + token * key_stride;
    const scalar_t* value_row = value + token * value_stride;
    scalar_t* k_dst = k_cache + dst_base;
    scalar_t* v_dst = v_cache + dst_base;
    for (int e = threadIdx.x; e < row; e += blockDim.x) {
      k_dst[e] = key_row[e];
      v_dst[e] = value_row[e];
    }
  }
  }  // namespace aphrodite
  // Launches reshape_and_cache_kernel with the given type triple.
  // KV_T is the data type of the incoming key and value tensors.
  // CACHE_T is the stored data type of the kv-cache.
  // KV_DTYPE is the Fp8KVCacheDataType that selects how values are converted.
  // Expects grid, block, stream and the reshape_and_cache locals (key, value,
  // key_cache, value_cache, slot_mapping, strides, sizes, x, kv_scale) in scope.
  #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
    aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>        \
        <<<grid, block, 0, stream>>>(                                   \
            reinterpret_cast<KV_T*>(key.data_ptr()),                    \
            reinterpret_cast<KV_T*>(value.data_ptr()),                  \
            reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
            reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
            slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
            num_heads, head_size, block_size, x, kv_scale);
  // Scatters each token's key/value vectors into the paged KV cache.
  //   key, value:   [num_tokens, num_heads, head_size]; rows are addressed via
  //                 stride(0), so the two inner dims are expected contiguous.
  //   key_cache:    [num_blocks, num_heads, head_size/x, block_size, x]
  //   value_cache:  [num_blocks, num_heads, head_size, block_size]
  //   slot_mapping: [num_tokens] int64 flat slot per token; negative = skip.
  //   kv_cache_dtype / kv_scale: choose the cache storage/conversion through
  //                 DISPATCH_BY_KV_CACHE_DTYPE.
  void reshape_and_cache(
      torch::Tensor& key,    // [num_tokens, num_heads, head_size]
      torch::Tensor& value,  // [num_tokens, num_heads, head_size]
      torch::Tensor&
          key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
      torch::Tensor&
          value_cache,  // [num_blocks, num_heads, head_size, block_size]
      torch::Tensor& slot_mapping,  // [num_tokens]
      const std::string& kv_cache_dtype, const float kv_scale) {
    int num_tokens = key.size(0);
    int num_heads = key.size(1);
    int head_size = key.size(2);
    int block_size = key_cache.size(3);
    int x = key_cache.size(4);
    int key_stride = key.stride(0);
    int value_stride = value.stride(0);
    if (num_tokens == 0) {
      // An empty batch would launch a zero-sized grid, an invalid configuration
      // whose error would only surface later through cudaGetLastError.
      return;
    }
    dim3 grid(num_tokens);
    dim3 block(std::min(num_heads * head_size, 512));
    const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // Expands CALL_RESHAPE_AND_CACHE with the types chosen from key's dtype and
    // kv_cache_dtype.
    DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                               CALL_RESHAPE_AND_CACHE)
  }
  // Scatters each token's key/value vectors into flash-attention style caches
  // laid out as [num_blocks, block_size, num_heads, head_size]. Negative
  // slot_mapping entries (padding) are skipped by the kernel.
  // Only kv_cache_dtype == "auto" (no conversion) is supported.
  void reshape_and_cache_flash(
      torch::Tensor& key,      // [num_tokens, num_heads, head_size]
      torch::Tensor& value,    // [num_tokens, num_heads, head_size]
      torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
      torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
      torch::Tensor& slot_mapping,  // [num_tokens]
      const std::string& kv_cache_dtype) {
    // FIXME: only support auto datatype, does not support fp8
    TORCH_CHECK(kv_cache_dtype == "auto",
                "Unsupported data type of kv cache: ", kv_cache_dtype);
    int num_tokens = key.size(0);
    int num_heads = key.size(1);
    int head_size = key.size(2);
    int block_size = k_cache.size(1);
    int key_stride = key.stride(0);
    int value_stride = value.stride(0);
    int block_stride = k_cache.stride(0);
    // The kernel uses a single block stride for both caches.
    TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0),
                "k_cache and v_cache must have the same block stride");
    if (num_tokens == 0) {
      // Avoid launching a zero-sized grid (invalid configuration).
      return;
    }
    dim3 grid(num_tokens);
    dim3 block(std::min(num_heads * head_size, 512));
    const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    APHRODITE_DISPATCH_FLOATING_TYPES(
        key.scalar_type(), "reshape_and_cache_flash", [&] {
          aphrodite::reshape_and_cache_flash_kernel<scalar_t>
              <<<grid, block, 0, stream>>>(
                  key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
                  k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
                  slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
                  value_stride, num_heads, head_size, block_size);
        });
  }
  namespace aphrodite {
  // Element-wise conversion of one cache block per CUDA block (blockIdx.x):
  // every element in [blockIdx.x * block_stride, (blockIdx.x + 1) *
  // block_stride) is passed through fp8::scaled_convert<Tout, Tin, kv_dt> with
  // kv_scale, threads stepping in strides of blockDim.x.
  template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
  __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                     Tout* __restrict__ dst_cache,
                                     const float kv_scale,
                                     const int64_t block_stride) {
    const int64_t base = static_cast<int64_t>(blockIdx.x) * block_stride;
    const Tin* src = src_cache + base;
    Tout* dst = dst_cache + base;
    for (int t = threadIdx.x; t < block_stride; t += blockDim.x) {
      dst[t] = fp8::scaled_convert<Tout, Tin, kv_dt>(src[t], kv_scale);
    }
  }
  }  // namespace aphrodite
  /* Launches aphrodite::convert_fp8_kernel<Tout, Tin, KV_DTYPE>, reading the
     grid, block, stream, src_cache, dst_cache, kv_scale and block_stride names
     from the calling scope. */
  #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                       \
    aphrodite::convert_fp8_kernel<Tout, Tin, KV_DTYPE>                \
        <<<grid, block, 0, stream>>>(                                 \
            static_cast<Tin*>(src_cache.data_ptr()),                  \
            static_cast<Tout*>(dst_cache.data_ptr()), kv_scale,       \
            block_stride);
  namespace {
  // Dispatches convert_fp8_kernel on the (src, dst) dtype pair for a fixed fp8
  // interpretation `kv_dt`. Source dtypes are tested first (float/half/bf16 ->
  // uint8 storage), then destination dtypes (uint8 storage -> float/half/bf16).
  // Any other combination raises instead of silently doing nothing. Parameter
  // names match what CALL_CONVERT_FP8 expects to find in scope.
  template <aphrodite::Fp8KVCacheDataType kv_dt>
  void launch_convert_fp8_kernel(torch::Tensor& dst_cache,
                                 torch::Tensor& src_cache, const float kv_scale,
                                 const int64_t block_stride, const dim3 grid,
                                 const dim3 block, const cudaStream_t stream) {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, kv_dt);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, kv_dt);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, kv_dt);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, kv_dt);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, kv_dt);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, kv_dt);
    } else {
      TORCH_CHECK(false, "Unsupported dtype combination for convert_fp8: src=",
                  src_cache.scalar_type(), ", dst=", dst_cache.scalar_type());
    }
  }
  }  // namespace

  // Only for testing.
  // Converts `src_cache` into `dst_cache` block by block (dim 0 of src_cache,
  // block_stride = src_cache.stride(0) elements each) on the current stream of
  // src's GPU. kv_cache_dtype "auto" selects kAuto; "fp8" / "fp8_e4m3" select
  // kFp8E4M3. Both caches must be on the same GPU.
  void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                   const float kv_scale, const std::string& kv_cache_dtype) {
    torch::Device src_device = src_cache.device();
    torch::Device dst_device = dst_cache.device();
    TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
    TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
    TORCH_CHECK(src_device.index() == dst_device.index(),
                "src and dst must be on the same GPU");
    const bool is_auto = kv_cache_dtype == "auto";
    const bool is_e4m3 = kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3";
    TORCH_CHECK(is_auto || is_e4m3, "Unsupported data type: ", kv_cache_dtype);

    at::cuda::OptionalCUDAGuard device_guard(src_device);
    int64_t num_blocks = src_cache.size(0);
    int64_t block_stride = src_cache.stride(0);
    if (num_blocks == 0 || block_stride == 0) {
      // A zero-sized grid or block is an invalid launch configuration.
      return;
    }
    dim3 grid(num_blocks);
    dim3 block(std::min(block_stride, int64_t(512)));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (is_auto) {
      launch_convert_fp8_kernel<aphrodite::Fp8KVCacheDataType::kAuto>(
          dst_cache, src_cache, kv_scale, block_stride, grid, block, stream);
    } else {
      launch_convert_fp8_kernel<aphrodite::Fp8KVCacheDataType::kFp8E4M3>(
          dst_cache, src_cache, kv_scale, block_stride, grid, block, stream);
    }
  }