#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const torch::Tensor& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // NOTE(youkaichao): keep in mind that `block_mapping` should be
  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
  // synchronization.
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(
    src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  const int64_t num_blocks = block_mapping.size(0);
  for (int64_t i = 0; i < num_blocks; i++) {
    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}

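// Usage sketch (illustrative only; tensor shapes and block numbers are made up):
// swap two blocks of a layer's cache from the GPU into pinned CPU memory. Each
// row of the CPU-side `block_mapping` is a (src_block, dst_block) pair.
//
//   torch::Tensor gpu_cache = torch::rand({num_blocks, block_numel}, torch::kCUDA);
//   torch::Tensor cpu_cache = torch::empty_like(gpu_cache, torch::kCPU).pin_memory();
//   torch::Tensor block_mapping =
//       torch::tensor({0, 3, 2, 5}, torch::kInt64).view({2, 2});  // 0 -> 3, 2 -> 5
//   swap_blocks(gpu_cache, cpu_cache, block_mapping);
//
// Note that cudaMemcpyAsync only overlaps with other work when the host buffer is
// pinned; with pageable CPU memory the copies effectively become synchronous.
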
namespace aphrodite {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite

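// Launch-shape note (an explanatory aside, not upstream documentation): the host
// wrapper below flattens the [num_pairs, 2] mapping into consecutive (src, dst)
// pairs, so CUDA block (layer_idx, pair_idx) copies exactly one block pair of one
// layer's key and value cache, with its threads striding over the numel_per_block
// elements. With, say, 32 layers and 4 pairs the launch is a 32 x 4 grid of
// thread blocks.
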
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const torch::Tensor& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  int num_pairs = block_mapping.size(0);

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping.data_ptr<int64_t>(),
        numel_per_block);
    }));
}

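// Usage sketch (illustrative only): copy block 7 into block 9 across every layer.
// Unlike swap_blocks, the mapping tensor is read inside the kernel, so it must
// already live on the GPU.
//
//   std::vector<torch::Tensor> key_caches;    // one cache tensor per layer, filled elsewhere
//   std::vector<torch::Tensor> value_caches;  // one cache tensor per layer, filled elsewhere
//   torch::Tensor block_mapping =
//       torch::tensor({7, 9}, torch::dtype(torch::kInt64).device(torch::kCUDA)).view({1, 2});
//   copy_blocks(key_caches, value_caches, block_mapping);
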
namespace aphrodite {

template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x,
  const float kv_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    } else {
      key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
      value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
    }
  }
}

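// Worked example of the index arithmetic above, with illustrative sizes
// (head_size = 128, x = 8, block_size = 16): a token assigned slot_idx = 35 lands
// in block_idx = 2 with block_offset = 3. For element i = 130 (head_idx = 1,
// head_offset = 2) the key goes to x_idx = 0, x_offset = 2, i.e. offset
//   2 * num_heads * 16 * 16 * 8  +  1 * 16 * 16 * 8  +  0 * 16 * 8  +  3 * 8  +  2
// into key_cache, while the value goes to offset
//   2 * num_heads * 128 * 16  +  1 * 128 * 16  +  2 * 16  +  3
// into value_cache.
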
template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ k_cache,             // [num_blocks, block_size, num_heads, head_size]
  scalar_t* __restrict__ v_cache,             // [num_blocks, block_size, num_heads, head_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int block_stride,
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int64_t tgt_value_idx = block_idx * block_stride
                                  + block_offset * num_heads * head_size
                                  + head_idx * head_size
                                  + head_offset;
    k_cache[tgt_value_idx] = key[src_key_idx];
    v_cache[tgt_value_idx] = value[src_value_idx];
  }
}

} // namespace aphrodite

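// Layout note (an aside; the rationale is our reading, not stated in this file):
// reshape_and_cache_kernel splits head_size into x-sized chunks, presumably so the
// paged attention kernel can read keys in vectorized units, whereas the flash
// layout above keeps [block_size, num_heads, head_size] contiguous, which is why
// the key and value writes share a single target index (tgt_value_idx).
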
// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the real (quantized) data type of the kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)                                      \
  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                 \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                               \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                        \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                      \
    slot_mapping.data_ptr<int64_t>(),                                                        \
    key_stride,                                                                              \
    value_stride,                                                                            \
    num_heads,                                                                               \
    head_size,                                                                               \
    block_size,                                                                              \
    x,                                                                                       \
    kv_scale);

void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype,
  const float kv_scale)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
}

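// Usage sketch (the x = 16 / element_size choice mirrors the usual paged cache
// layout but is an assumption here; this function only reads x back from
// key_cache.size(4)):
//
//   const int64_t x = 16 / key.element_size();
//   auto opts = key.options();
//   torch::Tensor key_cache =
//       torch::empty({num_blocks, num_heads, head_size / x, block_size, x}, opts);
//   torch::Tensor value_cache =
//       torch::empty({num_blocks, num_heads, head_size, block_size}, opts);
//   // slot_mapping: int64 tensor with one cache slot per token, -1 for padding.
//   reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, "auto", 1.0f);
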
void reshape_and_cache_flash(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size]
  torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  // FIXME: only supports the "auto" kv-cache dtype; fp8 is not supported yet.
  if (kv_cache_dtype != "auto") {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = k_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = k_cache.stride(0);
  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "reshape_and_cache_flash",
    [&] {
      aphrodite::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        k_cache.data_ptr<scalar_t>(),
        v_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(),
        block_stride,
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size);
    });
}

namespace aphrodite {

template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const float kv_scale,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
    dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
  }
}

} // namespace aphrodite

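// Launch-shape note: the host wrapper below launches one thread block per cache
// block (grid = num_blocks) and lets its threads stride over stride(0) elements,
// so it implicitly assumes the cache tensors are contiguous past their leading
// dimension; block_stride then equals the number of elements per block.
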
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                       \
  aphrodite::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                                   \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                                  \
    kv_scale,                                                                       \
    block_stride);

// Only for testing.
void convert_fp8(
  torch::Tensor& dst_cache,
  torch::Tensor& src_cache,
  const float kv_scale,
  const std::string& kv_cache_dtype)
{
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
  TORCH_CHECK(
    src_device.index() == dst_device.index(),
    "src and dst must be on the same GPU");
  at::cuda::OptionalCUDAGuard device_guard(src_device);

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (kv_cache_dtype == "auto") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, aphrodite::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, aphrodite::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, aphrodite::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, aphrodite::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kAuto);
    }
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
  }
}

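// Round-trip sketch (test-only, in the spirit of the note above; shapes and the
// scale are illustrative): quantize an fp16 cache to fp8 bytes and back.
//
//   auto gpu_half = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   torch::Tensor src = torch::rand({num_blocks, block_numel}, gpu_half);
//   torch::Tensor quantized = torch::empty_like(src, torch::kByte);  // fp8 stored as uint8
//   torch::Tensor recovered = torch::empty_like(src);
//   convert_fp8(quantized, src, /*kv_scale=*/1.0f, "fp8");        // half -> fp8
//   convert_fp8(recovered, quantized, /*kv_scale=*/1.0f, "fp8");  // fp8 -> half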