// KV-cache maintenance kernels: block swapping/copying between caches,
// scattering freshly computed key/value tensors into the paged cache
// (optionally quantizing to INT8 / FP8-E5M2), gathering them back, and
// converting whole caches to/from FP8-E5M2.
//
// NOTE(review): the include targets, template parameter lists and kernel
// launch configurations below were restored from how each name is used in
// this file -- confirm them against the build before merging.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/int8_kvcache/quant_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Storage format of the KV cache, used as a compile-time switch in
// reshape_and_cache_kernel.
enum kv_cache_dtype {
  AUTO,
#ifdef ENABLE_FP8_E5M2
  FP8_E5M2,
#endif
  INT8};

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

namespace {

// Surfaces launch-configuration errors (zero-sized or oversized grid/block,
// etc.) that would otherwise be dropped silently after a <<<...>>> launch.
inline void check_kernel_launch(const char* kernel_name) {
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              kernel_name, " launch failed: ", cudaGetErrorString(err));
}

} // namespace

// Copies whole cache blocks from `src` to `dst` on the current CUDA stream.
//
// src, dst:       tensors whose first dimension indexes blocks; they may live
//                 on the same GPU, or one on the GPU and one on the CPU.
// block_mapping:  source block number -> destination block number.
//
// Raises (via TORCH_CHECK) for an unsupported device combination or if any
// asynchronous copy cannot be enqueued.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // Nothing to copy; also avoids indexing src[0] for no reason.
  if (block_mapping.empty()) {
    return;
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  // Exactly one side may be a CPU tensor; guard on whichever side is CUDA.
  const at::cuda::OptionalCUDAGuard device_guard(
    src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE: This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    const int64_t src_block_number = pair.first;
    const int64_t dst_block_number = pair.second;
    const int64_t src_offset = src_block_number * block_size_in_bytes;
    const int64_t dst_offset = dst_block_number * block_size_in_bytes;
    const cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
    TORCH_CHECK(err == cudaSuccess,
                "swap_blocks: cudaMemcpyAsync failed: ", cudaGetErrorString(err));
  }
}

namespace aphrodite {

// Copies cache blocks within each layer's key and value cache.
//
// Grid: (num_layers, num_pairs). Each thread block handles one
// (layer, src->dst pair) and strides over the block's elements by blockDim.x.
//
// key_cache_ptrs / value_cache_ptrs: per-layer base addresses stored as int64.
// block_mapping: flattened [src0, dst0, src1, dst1, ...] block numbers.
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace aphrodite

// Duplicates cache blocks inside every layer's key/value cache.
//
// key_caches, value_caches: one tensor per layer, all on the same CUDA device;
//                           the first dimension indexes blocks.
// block_mapping:            source block number -> destination block numbers.
//
// Returns without launching when there are no layers, no (src, dst) pairs, or
// the blocks are empty.
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "key_caches and value_caches must have the same number of layers");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  // (std::vector instead of a variable-length array, which is not standard C++.)
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
      reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
      reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Create block mapping array: flattened (src, dst) pairs.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    const int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  const int num_pairs = static_cast<int>(block_mapping_vec.size() / 2);
  const int numel_per_block = key_caches[0][0].numel();
  // A zero-sized grid or block dimension is an invalid launch configuration;
  // there is nothing to copy in either case.
  if (num_pairs == 0 || numel_per_block == 0) {
    return;
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_vec.data(), {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
  check_kernel_launch("copy_blocks_kernel");
}

namespace aphrodite {

// Scatters one token's key/value vectors into the paged KV cache, converting
// to the cache storage type selected by KV_CACHE_DTYPE.
//
// Grid: (num_tokens). Each thread block handles one token and strides over its
// num_heads * head_size elements by blockDim.x. Tokens whose slot is negative
// are padding and are skipped.
template<typename scalar_t, typename cache_t, kv_cache_dtype KV_CACHE_DTYPE>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x,
  const float k_scale,
  const float k_zp,
  const float v_scale,
  const float v_zp) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (KV_CACHE_DTYPE == INT8) {
      key_cache[tgt_key_idx] = int8::quant(tgt_key, k_scale, k_zp);
      value_cache[tgt_value_idx] = int8::quant(tgt_value, v_scale, v_zp);
#ifdef ENABLE_FP8_E5M2
    } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace aphrodite

// Launches reshape_and_cache_kernel for one (input type, cache type, mode)
// combination. Expects grid, block, stream and all kernel arguments to be in
// scope at the expansion site.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_CACHE_DTYPE)                                      \
  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_CACHE_DTYPE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
    slot_mapping.data_ptr<int64_t>(),                                                              \
    key_stride,                                                                                    \
    value_stride,                                                                                  \
    num_heads,                                                                                     \
    head_size,                                                                                     \
    block_size,                                                                                    \
    x,                                                                                             \
    k_scale,                                                                                       \
    k_zp,                                                                                          \
    v_scale,                                                                                       \
    v_zp);

// Writes `key` / `value` into the paged caches at the slots in `slot_mapping`.
//
// kv_cache_dtype selects the cache storage: "auto" (same type as the input),
// "fp8_e5m2" (only when built with ENABLE_FP8_E5M2) or "int8" (quantized with
// the given scale / zero-point pairs).
//
// Raises (via TORCH_CHECK) on an unknown kv_cache_dtype, on an input dtype
// other than float / half / bfloat16, or if the launch fails.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype,
  const float k_scale = 1.0f,
  const float k_zp = 0.0f,
  const float v_scale = 1.0f,
  const float v_zp = 0.0f)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  // A zero-sized grid or block is an invalid launch configuration.
  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, AUTO);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, AUTO);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, AUTO);
    } else {
      // Previously a silent no-op that left the cache unwritten.
      TORCH_CHECK(false, "Unsupported input data type for reshape_and_cache: ",
                  key.scalar_type());
    }
#ifdef ENABLE_FP8_E5M2
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, FP8_E5M2);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, FP8_E5M2);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, FP8_E5M2);
    } else {
      TORCH_CHECK(false, "Unsupported input data type for reshape_and_cache: ",
                  key.scalar_type());
    }
#endif
  } else if (kv_cache_dtype == "int8") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, int8_t, INT8);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, int8_t, INT8);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, int8_t, INT8);
    } else {
      TORCH_CHECK(false, "Unsupported input data type for reshape_and_cache: ",
                  key.scalar_type());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
  check_kernel_launch("reshape_and_cache_kernel");
}

namespace aphrodite {

// Gathers one token's key/value vectors out of the paged KV cache.
//
// Grid: (num_tokens). Each thread block handles one token and strides over its
// num_heads * head_size elements by blockDim.x. Works for any element count.
// Tokens whose slot is negative are skipped (their output is left untouched)
// instead of reading before the start of the cache.
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,                 // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,               // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,       // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    return;
  }
  // 64-bit from here on: block_idx * (elements per block) can exceed 2^31.
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  for (int i = threadIdx.x; i < dim; i += blockDim.x) {
    const int64_t tgt_key_idx = static_cast<int64_t>(token_idx) * key_stride + i;
    const int64_t tgt_value_idx = static_cast<int64_t>(token_idx) * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int64_t src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t src_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;

    key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
  }
}

// Same contract as gather_cached_kv_kernel, but each thread issues four loads
// before the matching four stores.
//
// Precondition: num_heads * head_size is a multiple of 4 (asserted); the host
// wrapper falls back to gather_cached_kv_kernel otherwise.
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
  scalar_t *__restrict__ key,                 // [num_tokens, [stride], num_heads, head_size]
  scalar_t *__restrict__ value,               // [num_tokens, [stride], num_heads, head_size]
  const scalar_t *__restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t *__restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int *__restrict__ slot_mapping,       // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    return;
  }
  // 64-bit from here on: block_idx * (elements per block) can exceed 2^31.
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  constexpr int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x) {
    int64_t tgt_key_indices[unroll_factor];
    int64_t tgt_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      const int index = i + j * unrolled_dim;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int64_t src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                  + head_idx * (head_size / x) * block_size * x
                                  + x_idx * block_size * x
                                  + block_offset * x
                                  + x_offset;
      const int64_t src_value_idx = block_idx * num_heads * head_size * block_size
                                    + head_idx * head_size * block_size
                                    + head_offset * block_size
                                    + block_offset;

      tgt_key_indices[j] = static_cast<int64_t>(token_idx) * key_stride + index;
      tgt_value_indices[j] = static_cast<int64_t>(token_idx) * value_stride + index;
      keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j) {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace aphrodite

// Reads key/value vectors for each token back out of the paged caches.
//
// Uses the 4-way unrolled kernel when num_heads * head_size is a multiple of
// 4, and the generic kernel otherwise so no elements are dropped.
void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  const int dim = num_heads * head_size;
  // A zero-sized grid or block is an invalid launch configuration.
  if (num_tokens == 0 || dim == 0) {
    return;
  }
  const bool use_unrolled = (dim % 4 == 0);

  dim3 grid(num_tokens);
  dim3 block(std::min(dim, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    ([&] {
      if (use_unrolled) {
        aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
          key.data_ptr<scalar_t>(),
          value.data_ptr<scalar_t>(),
          key_cache.data_ptr<scalar_t>(),
          value_cache.data_ptr<scalar_t>(),
          slot_mapping.data_ptr<int>(),
          key_stride,
          value_stride,
          num_heads,
          head_size,
          block_size,
          x);
      } else {
        aphrodite::gather_cached_kv_kernel<scalar_t><<<grid, block, 0, stream>>>(
          key.data_ptr<scalar_t>(),
          value.data_ptr<scalar_t>(),
          key_cache.data_ptr<scalar_t>(),
          value_cache.data_ptr<scalar_t>(),
          slot_mapping.data_ptr<int>(),
          key_stride,
          value_stride,
          num_heads,
          head_size,
          block_size,
          x);
      }
    }));
  check_kernel_launch("gather_cached_kv_kernel");
}

namespace aphrodite {

// Element-wise conversion of a cache to or from FP8-E5M2.
//
// Grid: (num_blocks). Each thread block converts block_stride consecutive
// elements, striding by blockDim.x.
template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  // 64-bit counter: block_stride is int64_t and need not fit in an int.
  for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace aphrodite

// Launches convert_fp8_e5m2_kernel for one (output type, input type) pair.
// Expects grid, block, stream, src_cache, dst_cache and block_stride in scope.
#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
  aphrodite::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
    block_stride);

// Converts `src_cache` into `dst_cache`, to or from FP8-E5M2 (stored as
// uint8). The direction is chosen from the tensor dtypes: a float / half /
// bfloat16 source is converted to FP8; otherwise a float / half / bfloat16
// destination is converted from FP8.
//
// Raises (via TORCH_CHECK) if the build lacks ENABLE_FP8_E5M2, if neither
// tensor has a supported dtype, or if the launch fails.
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
#ifndef ENABLE_FP8_E5M2
  // Without FP8 support the kernel body is only `assert(false)`; fail on the
  // host with a clear message instead.
  TORCH_CHECK(false, "convert_fp8_e5m2 requires a build with ENABLE_FP8_E5M2");
#endif
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  // A zero-sized grid or block is an invalid launch configuration.
  if (num_blocks == 0 || block_stride == 0) {
    return;
  }

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  // Make the cache's device current so the stream below belongs to it.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(src_cache));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  } else {
    // Previously a silent no-op that left dst_cache unwritten.
    TORCH_CHECK(false, "Unsupported data types for convert_fp8_e5m2: src=",
                src_cache.scalar_type(), ", dst=", dst_cache.scalar_type());
  }
  check_kernel_launch("convert_fp8_e5m2_kernel");
}