cache_kernels.cu 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif
  19. void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
  20. const torch::Tensor& block_mapping) {
  21. torch::Device src_device = src.device();
  22. torch::Device dst_device = dst.device();
  23. cudaMemcpyKind memcpy_type;
  24. if (src_device.is_cuda() && dst_device.is_cuda()) {
  25. TORCH_CHECK(src_device.index() == dst_device.index(),
  26. "src and dst must be on the same GPU");
  27. memcpy_type = cudaMemcpyDeviceToDevice;
  28. } else if (src_device.is_cuda() && dst_device.is_cpu()) {
  29. memcpy_type = cudaMemcpyDeviceToHost;
  30. } else if (src_device.is_cpu() && dst_device.is_cuda()) {
  31. memcpy_type = cudaMemcpyHostToDevice;
  32. } else {
  33. TORCH_CHECK(false, "Invalid device combination");
  34. }
  35. // NOTE: keep in mind that `block_mapping` should be
  36. // a cpu tensor, otherwise every `item` call will require a gpu-cpu
  37. // synchronization.
  38. TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
  39. char* src_ptr = static_cast<char*>(src.data_ptr());
  40. char* dst_ptr = static_cast<char*>(dst.data_ptr());
  41. const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  42. const at::cuda::OptionalCUDAGuard device_guard(
  43. src_device.is_cuda() ? src_device : dst_device);
  44. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  45. // NOTE: This can be slow if the number of blocks is large.
  46. const int64_t num_blocks = block_mapping.size(0);
  47. for (size_t i = 0; i < num_blocks; i++) {
  48. int64_t src_block_number = block_mapping[i][0].item<int64_t>();
  49. int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
  50. int64_t src_offset = src_block_number * block_size_in_bytes;
  51. int64_t dst_offset = dst_block_number * block_size_in_bytes;
  52. cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
  53. block_size_in_bytes, memcpy_type, stream);
  54. }
  55. }
  56. namespace aphrodite {
  57. // Grid: (num_layers, num_pairs)
  58. template <typename scalar_t>
  59. __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
  60. int64_t* value_cache_ptrs,
  61. const int64_t* __restrict__ block_mapping,
  62. const int numel_per_block) {
  63. const int layer_idx = blockIdx.x;
  64. const int pair_idx = blockIdx.y;
  65. scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  66. scalar_t* value_cache =
  67. reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  68. int64_t src_block_number = block_mapping[2 * pair_idx];
  69. int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
  70. const int64_t src_block_offset = src_block_number * numel_per_block;
  71. const int64_t dst_block_offset = dst_block_number * numel_per_block;
  72. for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
  73. int64_t src_offset = src_block_offset + i;
  74. int64_t dst_offset = dst_block_offset + i;
  75. key_cache[dst_offset] = key_cache[src_offset];
  76. }
  77. for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
  78. int64_t src_offset = src_block_offset + i;
  79. int64_t dst_offset = dst_block_offset + i;
  80. value_cache[dst_offset] = value_cache[src_offset];
  81. }
  82. }
  83. } // namespace aphrodite
  84. // Note: the key_caches and value_caches vectors are constant but
  85. // not the Tensors they contain. The vectors need to be const refs
  86. // in order to satisfy pytorch's C++ operator registration code.
  87. void copy_blocks(std::vector<torch::Tensor> const& key_caches,
  88. std::vector<torch::Tensor> const& value_caches,
  89. const torch::Tensor& block_mapping) {
  90. int num_layers = key_caches.size();
  91. TORCH_CHECK(num_layers == value_caches.size());
  92. if (num_layers == 0) {
  93. return;
  94. }
  95. torch::Device cache_device = key_caches[0].device();
  96. TORCH_CHECK(cache_device.is_cuda());
  97. // Create data structures for the kernel.
  98. // Create an array of pointers to the key and value caches.
  99. std::vector<int64_t> key_cache_ptrs(num_layers);
  100. std::vector<int64_t> value_cache_ptrs(num_layers);
  101. for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
  102. key_cache_ptrs[layer_idx] =
  103. reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
  104. value_cache_ptrs[layer_idx] =
  105. reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  106. }
  107. // block_mapping is a 2D tensor with shape (num_pairs, 2).
  108. int num_pairs = block_mapping.size(0);
  109. // Move the data structures to the GPU.
  110. // NOTE: This synchronizes the CPU and GPU.
  111. torch::Tensor key_cache_ptrs_tensor =
  112. torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
  113. .to(cache_device);
  114. torch::Tensor value_cache_ptrs_tensor =
  115. torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
  116. .to(cache_device);
  117. // Launch the kernel.
  118. const int numel_per_block = key_caches[0][0].numel();
  119. dim3 grid(num_layers, num_pairs);
  120. dim3 block(std::min(1024, numel_per_block));
  121. const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  122. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  123. APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
  124. key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
  125. aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
  126. key_cache_ptrs_tensor.data_ptr<int64_t>(),
  127. value_cache_ptrs_tensor.data_ptr<int64_t>(),
  128. block_mapping.data_ptr<int64_t>(), numel_per_block);
  129. }));
  130. }
  131. namespace aphrodite {
  132. template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
  133. __global__ void reshape_and_cache_kernel(
  134. const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
  135. const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
  136. cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
  137. // block_size, x]
  138. cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
  139. // block_size]
  140. const int64_t* __restrict__ slot_mapping, // [num_tokens]
  141. const int key_stride, const int value_stride, const int num_heads,
  142. const int head_size, const int block_size, const int x, const float k_scale,
  143. const float v_scale) {
  144. const int64_t token_idx = blockIdx.x;
  145. const int64_t slot_idx = slot_mapping[token_idx];
  146. if (slot_idx < 0) {
  147. // Padding token that should be ignored.
  148. return;
  149. }
  150. const int64_t block_idx = slot_idx / block_size;
  151. const int64_t block_offset = slot_idx % block_size;
  152. const int n = num_heads * head_size;
  153. for (int i = threadIdx.x; i < n; i += blockDim.x) {
  154. const int64_t src_key_idx = token_idx * key_stride + i;
  155. const int64_t src_value_idx = token_idx * value_stride + i;
  156. const int head_idx = i / head_size;
  157. const int head_offset = i % head_size;
  158. const int x_idx = head_offset / x;
  159. const int x_offset = head_offset % x;
  160. const int64_t tgt_key_idx =
  161. block_idx * num_heads * (head_size / x) * block_size * x +
  162. head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
  163. block_offset * x + x_offset;
  164. const int64_t tgt_value_idx =
  165. block_idx * num_heads * head_size * block_size +
  166. head_idx * head_size * block_size + head_offset * block_size +
  167. block_offset;
  168. scalar_t tgt_key = key[src_key_idx];
  169. scalar_t tgt_value = value[src_value_idx];
  170. if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
  171. key_cache[tgt_key_idx] = tgt_key;
  172. value_cache[tgt_value_idx] = tgt_value;
  173. } else {
  174. key_cache[tgt_key_idx] =
  175. fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
  176. value_cache[tgt_value_idx] =
  177. fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
  178. }
  179. }
  180. }
  181. template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
  182. __global__ void reshape_and_cache_flash_kernel(
  183. const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
  184. const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
  185. cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
  186. // head_size]
  187. cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
  188. // head_size]
  189. const int64_t* __restrict__ slot_mapping, // [num_tokens]
  190. const int block_stride, const int key_stride, const int value_stride,
  191. const int num_heads, const int head_size, const int block_size,
  192. const float k_scale, const float v_scale) {
  193. const int64_t token_idx = blockIdx.x;
  194. const int64_t slot_idx = slot_mapping[token_idx];
  195. // NOTE: slot_idx can be -1 if the token is padded
  196. if (slot_idx < 0) {
  197. return;
  198. }
  199. const int64_t block_idx = slot_idx / block_size;
  200. const int64_t block_offset = slot_idx % block_size;
  201. const int n = num_heads * head_size;
  202. for (int i = threadIdx.x; i < n; i += blockDim.x) {
  203. const int64_t src_key_idx = token_idx * key_stride + i;
  204. const int64_t src_value_idx = token_idx * value_stride + i;
  205. const int head_idx = i / head_size;
  206. const int head_offset = i % head_size;
  207. const int64_t tgt_key_value_idx = block_idx * block_stride +
  208. block_offset * num_heads * head_size +
  209. head_idx * head_size + head_offset;
  210. scalar_t tgt_key = key[src_key_idx];
  211. scalar_t tgt_value = value[src_value_idx];
  212. if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
  213. key_cache[tgt_key_value_idx] = tgt_key;
  214. value_cache[tgt_key_value_idx] = tgt_value;
  215. } else {
  216. key_cache[tgt_key_value_idx] =
  217. fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
  218. value_cache[tgt_key_value_idx] =
  219. fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
  220. }
  221. }
  222. }
  223. } // namespace aphrodite
  224. // KV_T is the stored data type of kv-cache.
  225. // CACHE_T is the data type of key and value tensors.
  226. // KV_DTYPE is the real data type of kv-cache.
  227. #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
  228. aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
  229. <<<grid, block, 0, stream>>>( \
  230. reinterpret_cast<KV_T*>(key.data_ptr()), \
  231. reinterpret_cast<KV_T*>(value.data_ptr()), \
  232. reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
  233. reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
  234. slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
  235. num_heads, head_size, block_size, x, k_scale, v_scale);
  236. void reshape_and_cache(
  237. torch::Tensor& key, // [num_tokens, num_heads, head_size]
  238. torch::Tensor& value, // [num_tokens, num_heads, head_size]
  239. torch::Tensor&
  240. key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
  241. torch::Tensor&
  242. value_cache, // [num_blocks, num_heads, head_size, block_size]
  243. torch::Tensor& slot_mapping, // [num_tokens]
  244. const std::string& kv_cache_dtype, const double k_scale,
  245. const double v_scale) {
  246. int num_tokens = key.size(0);
  247. int num_heads = key.size(1);
  248. int head_size = key.size(2);
  249. int block_size = key_cache.size(3);
  250. int x = key_cache.size(4);
  251. int key_stride = key.stride(0);
  252. int value_stride = value.stride(0);
  253. dim3 grid(num_tokens);
  254. dim3 block(std::min(num_heads * head_size, 512));
  255. const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  256. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  257. DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
  258. CALL_RESHAPE_AND_CACHE)
  259. }
  260. // KV_T is the stored data type of kv-cache.
  261. // CACHE_T is the data type of key and value tensors.
  262. // KV_DTYPE is the real data type of kv-cache.
  263. #define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
  264. aphrodite::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
  265. <<<grid, block, 0, stream>>>( \
  266. reinterpret_cast<KV_T*>(key.data_ptr()), \
  267. reinterpret_cast<KV_T*>(value.data_ptr()), \
  268. reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
  269. reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
  270. slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
  271. value_stride, num_heads, head_size, block_size, k_scale, v_scale);
  272. void reshape_and_cache_flash(
  273. torch::Tensor& key, // [num_tokens, num_heads, head_size]
  274. torch::Tensor& value, // [num_tokens, num_heads, head_size]
  275. torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
  276. torch::Tensor&
  277. value_cache, // [num_blocks, block_size, num_heads, head_size]
  278. torch::Tensor& slot_mapping, // [num_tokens]
  279. const std::string& kv_cache_dtype, const double k_scale,
  280. const double v_scale) {
  281. int num_tokens = key.size(0);
  282. int num_heads = key.size(1);
  283. int head_size = key.size(2);
  284. int block_size = key_cache.size(1);
  285. int key_stride = key.stride(0);
  286. int value_stride = value.stride(0);
  287. int block_stride = key_cache.stride(0);
  288. TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
  289. dim3 grid(num_tokens);
  290. dim3 block(std::min(num_heads * head_size, 512));
  291. const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  292. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  293. DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
  294. CALL_RESHAPE_AND_CACHE_FLASH);
  295. }
  296. namespace aphrodite {
  297. template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
  298. __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
  299. Tout* __restrict__ dst_cache,
  300. const float scale,
  301. const int64_t block_stride) {
  302. const int64_t block_idx = blockIdx.x;
  303. for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
  304. int64_t idx = block_idx * block_stride + i;
  305. dst_cache[idx] =
  306. fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
  307. }
  308. }
  309. } // namespace aphrodite
  310. #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
  311. aphrodite::convert_fp8_kernel<Tout, Tin, KV_DTYPE> \
  312. <<<grid, block, 0, stream>>>( \
  313. reinterpret_cast<Tin*>(src_cache.data_ptr()), \
  314. reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
  315. // Only for testing.
  316. void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
  317. const double scale, const std::string& kv_cache_dtype) {
  318. torch::Device src_device = src_cache.device();
  319. torch::Device dst_device = dst_cache.device();
  320. TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
  321. TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
  322. TORCH_CHECK(src_device.index() == dst_device.index(),
  323. "src and dst must be on the same GPU");
  324. at::cuda::OptionalCUDAGuard device_guard(src_device);
  325. int64_t num_blocks = src_cache.size(0);
  326. int64_t block_stride = src_cache.stride(0);
  327. dim3 grid(num_blocks);
  328. dim3 block(std::min(block_stride, int64_t(512)));
  329. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  330. if (kv_cache_dtype == "auto") {
  331. if (src_cache.dtype() == at::ScalarType::Float) {
  332. CALL_CONVERT_FP8(uint8_t, float, aphrodite::Fp8KVCacheDataType::kAuto);
  333. } else if (src_cache.dtype() == at::ScalarType::Half) {
  334. CALL_CONVERT_FP8(uint8_t, uint16_t, aphrodite::Fp8KVCacheDataType::kAuto);
  335. } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
  336. CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
  337. aphrodite::Fp8KVCacheDataType::kAuto);
  338. } else if (dst_cache.dtype() == at::ScalarType::Float) {
  339. CALL_CONVERT_FP8(float, uint8_t, aphrodite::Fp8KVCacheDataType::kAuto);
  340. } else if (dst_cache.dtype() == at::ScalarType::Half) {
  341. CALL_CONVERT_FP8(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kAuto);
  342. } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
  343. CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
  344. aphrodite::Fp8KVCacheDataType::kAuto);
  345. }
  346. } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
  347. if (src_cache.dtype() == at::ScalarType::Float) {
  348. CALL_CONVERT_FP8(uint8_t, float, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
  349. } else if (src_cache.dtype() == at::ScalarType::Half) {
  350. CALL_CONVERT_FP8(uint8_t, uint16_t,
  351. aphrodite::Fp8KVCacheDataType::kFp8E4M3);
  352. } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
  353. CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
  354. aphrodite::Fp8KVCacheDataType::kFp8E4M3);
  355. } else if (dst_cache.dtype() == at::ScalarType::Float) {
  356. CALL_CONVERT_FP8(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
  357. } else if (dst_cache.dtype() == at::ScalarType::Half) {
  358. CALL_CONVERT_FP8(uint16_t, uint8_t,
  359. aphrodite::Fp8KVCacheDataType::kFp8E4M3);
  360. } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
  361. CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
  362. aphrodite::Fp8KVCacheDataType::kFp8E4M3);
  363. }
  364. } else {
  365. TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
  366. }
  367. }