// cache_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif

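// Copies whole cache blocks between `src` and `dst` according to `block_mapping`
// (source block number -> destination block number). The tensors may live on
// different devices (GPU<->GPU on the same device index, or GPU<->CPU), and each
// copy is issued with cudaMemcpyAsync on the current stream.
// For example, block_mapping = {{0, 3}, {1, 7}} copies src block 0 into dst
// block 3 and src block 1 into dst block 7.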
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE: This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}

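// Device-side block copy. Each thread block handles one (layer, src/dst pair):
// `block_mapping` is a flattened [src0, dst0, src1, dst1, ...] array built by
// copy_blocks() below, and the per-layer cache pointers are passed as int64_t
// addresses so a single launch covers all layers.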
namespace aphrodite {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite

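// Host-side entry point for copy_blocks_kernel: gathers the per-layer cache
// pointers and the flattened (src, dst) pairs on the CPU, moves them to the
// cache device, and launches one (num_layers x num_pairs) grid. A single
// source block may be mapped to several destination blocks.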
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}

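// Scatters each token's key/value vectors into the paged KV cache.
// `slot_mapping[token]` is the token's absolute slot in the cache
// (block_idx * block_size + block_offset); a negative slot marks a padding
// token and is skipped. The key cache additionally splits head_size into
// groups of `x` elements so that runs of `x` consecutive elements of a head
// are stored contiguously, which is the layout the attention kernel expects.
//
// Illustrative index calculation (values chosen only for this example): with
// num_heads = 8, head_size = 128, block_size = 16, x = 8, a token mapped to
// slot 35 falls in block_idx = 2, block_offset = 3. Element 5 of head 1
// (head_idx = 1, head_offset = 5, x_idx = 0, x_offset = 5) is written to
//   tgt_key_idx = 2*8*(128/8)*16*8 + 1*(128/8)*16*8 + 0*16*8 + 3*8 + 5
//               = 32768 + 2048 + 0 + 24 + 5 = 34845.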
namespace aphrodite {

template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x,
  const float kv_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
      key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
      value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace aphrodite

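// Launch helper: instantiates reshape_and_cache_kernel for a particular
// (source type, cache type, fp8 flag) combination, using the grid/block/stream
// variables defined in reshape_and_cache() below.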
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE)                                      \
  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                        \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                      \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                               \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                             \
    slot_mapping.data_ptr<int64_t>(),                                                               \
    key_stride,                                                                                     \
    value_stride,                                                                                   \
    num_heads,                                                                                      \
    head_size,                                                                                      \
    block_size,                                                                                     \
    x,                                                                                              \
    kv_scale);

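// One thread block per token; up to 512 threads stride over that token's
// num_heads * head_size elements. kv_cache_dtype == "auto" keeps the cache in
// the input dtype, while "fp8" quantizes it into a uint8-backed fp8 cache.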
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype,
  const float kv_scale)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else if (kv_cache_dtype == "fp8") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

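// Element-wise conversion kernel for moving a cache between its fp8
// (uint8-backed) representation and a higher-precision dtype. One thread block
// per cache block, striding over `block_stride` elements; the conversion
// routine is selected at compile time by ENABLE_FP8_E5M2 / ENABLE_FP8_E4M3.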
namespace aphrodite {

template<typename Tout, typename Tin>
__global__ void convert_fp8_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#if defined(ENABLE_FP8_E5M2)
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#elif defined(ENABLE_FP8_E4M3)
    dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace aphrodite

#define CALL_CONVERT_FP8(Tout, Tin) \
  aphrodite::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
    reinterpret_cast<Tin*>(src_cache.data_ptr()), \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
    block_stride);

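// The conversion direction is inferred from the dtypes: a float/half/bfloat16
// source cache is quantized into a uint8 (fp8) destination; otherwise a uint8
// source cache is dequantized into the destination's dtype. Both tensors must
// be on the same GPU.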
void convert_fp8(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
  TORCH_CHECK(
    src_device.index() == dst_device.index(),
    "src and dst must be on the same GPU");
  at::cuda::OptionalCUDAGuard device_guard(src_device);

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
  }
}