- #include <torch/all.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <c10/cuda/CUDAGuard.h>
- #include "cuda_compat.h"
- #include "dispatch_utils.h"
- #ifdef USE_ROCM
- #include "quantization/fp8/amd/quant_utils.cuh"
- #else
- #include "quantization/fp8/nvidia/quant_utils.cuh"
- #endif
- #include <algorithm>
- #include <cassert>
- #include <map>
- #include <vector>
- #ifdef USE_ROCM
- #include <hip/hip_bf16.h>
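- // Alias ROCm's bf16 type to the CUDA name so the kernels below compile
- // unchanged on both backends.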
- typedef __hip_bfloat16 __nv_bfloat16;
- #endif
- void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
- const torch::Tensor& block_mapping) {
- torch::Device src_device = src.device();
- torch::Device dst_device = dst.device();
- cudaMemcpyKind memcpy_type;
- if (src_device.is_cuda() && dst_device.is_cuda()) {
- TORCH_CHECK(src_device.index() == dst_device.index(),
- "src and dst must be on the same GPU");
- memcpy_type = cudaMemcpyDeviceToDevice;
- } else if (src_device.is_cuda() && dst_device.is_cpu()) {
- memcpy_type = cudaMemcpyDeviceToHost;
- } else if (src_device.is_cpu() && dst_device.is_cuda()) {
- memcpy_type = cudaMemcpyHostToDevice;
- } else {
- TORCH_CHECK(false, "Invalid device combination");
- }
- // NOTE: `block_mapping` must be a CPU tensor; otherwise every `.item()`
- // call below would trigger a GPU-CPU synchronization.
- TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
- char* src_ptr = static_cast<char*>(src.data_ptr());
- char* dst_ptr = static_cast<char*>(dst.data_ptr());
- const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
- const at::cuda::OptionalCUDAGuard device_guard(
- src_device.is_cuda() ? src_device : dst_device);
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- // NOTE: This can be slow if the number of blocks is large.
- const int64_t num_blocks = block_mapping.size(0);
- for (int64_t i = 0; i < num_blocks; i++) {
- int64_t src_block_number = block_mapping[i][0].item<int64_t>();
- int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
- int64_t src_offset = src_block_number * block_size_in_bytes;
- int64_t dst_offset = dst_block_number * block_size_in_bytes;
- cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
- block_size_in_bytes, memcpy_type, stream);
- }
- }
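- // Illustrative host-side use of swap_blocks (a sketch; `gpu_cache` and
- // `cpu_cache` are hypothetical cache tensors of identical shape
- // [num_blocks, ...], one on CUDA and one on CPU). Each row of the mapping
- // names a (src, dst) block pair, and block i starts at byte offset
- // i * element_size * numel-per-block:
- //
- //   torch::Tensor mapping =
- //       torch::tensor({{0, 3}, {1, 7}},
- //                     torch::dtype(torch::kInt64).device(torch::kCPU));
- //   swap_blocks(gpu_cache, cpu_cache, mapping);  // copy blocks 0->3, 1->7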
- namespace aphrodite {
- // Grid: (num_layers, num_pairs)
- template <typename scalar_t>
- __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
- int64_t* value_cache_ptrs,
- const int64_t* __restrict__ block_mapping,
- const int numel_per_block) {
- const int layer_idx = blockIdx.x;
- const int pair_idx = blockIdx.y;
- scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
- scalar_t* value_cache =
- reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
- int64_t src_block_number = block_mapping[2 * pair_idx];
- int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
- const int64_t src_block_offset = src_block_number * numel_per_block;
- const int64_t dst_block_offset = dst_block_number * numel_per_block;
- for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
- int64_t src_offset = src_block_offset + i;
- int64_t dst_offset = dst_block_offset + i;
- key_cache[dst_offset] = key_cache[src_offset];
- }
- for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
- int64_t src_offset = src_block_offset + i;
- int64_t dst_offset = dst_block_offset + i;
- value_cache[dst_offset] = value_cache[src_offset];
- }
- }
- } // namespace aphrodite
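- // copy_blocks_kernel is launched with grid = (num_layers, num_pairs):
- // blockIdx.x selects the layer's cache tensors, blockIdx.y selects the
- // (src, dst) pair, and the threads of a block stride over the block's
- // elements. For example, with numel_per_block = 4096 and blockDim.x = 1024,
- // each thread copies 4 key elements and 4 value elements.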
- // Note: the key_caches and value_caches vectors are constant but
- // not the Tensors they contain. The vectors need to be const refs
- // in order to satisfy PyTorch's C++ operator registration code.
- void copy_blocks(std::vector<torch::Tensor> const& key_caches,
- std::vector<torch::Tensor> const& value_caches,
- const torch::Tensor& block_mapping) {
- int num_layers = key_caches.size();
- TORCH_CHECK(num_layers == static_cast<int>(value_caches.size()),
- "key_caches and value_caches must have the same length");
- if (num_layers == 0) {
- return;
- }
- torch::Device cache_device = key_caches[0].device();
- TORCH_CHECK(cache_device.is_cuda());
- // Create data structures for the kernel.
- // Create an array of pointers to the key and value caches.
- std::vector<int64_t> key_cache_ptrs(num_layers);
- std::vector<int64_t> value_cache_ptrs(num_layers);
- for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
- key_cache_ptrs[layer_idx] =
- reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
- value_cache_ptrs[layer_idx] =
- reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
- }
- // block_mapping is a 2D tensor with shape (num_pairs, 2).
- int num_pairs = block_mapping.size(0);
- // Move the data structures to the GPU.
- // NOTE: This synchronizes the CPU and GPU.
- torch::Tensor key_cache_ptrs_tensor =
- torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
- .to(cache_device);
- torch::Tensor value_cache_ptrs_tensor =
- torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
- .to(cache_device);
- // Launch the kernel.
- const int numel_per_block = key_caches[0][0].numel();
- dim3 grid(num_layers, num_pairs);
- dim3 block(std::min(1024, numel_per_block));
- const at::cuda::OptionalCUDAGuard device_guard(cache_device);
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
- key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
- aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
- key_cache_ptrs_tensor.data_ptr<int64_t>(),
- value_cache_ptrs_tensor.data_ptr<int64_t>(),
- block_mapping.data_ptr<int64_t>(), numel_per_block);
- }));
- }
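- // Illustrative call (a sketch; unlike swap_blocks, the kernel dereferences
- // block_mapping on the device, so it must live on the same GPU as the
- // caches):
- //
- //   torch::Tensor mapping = torch::tensor(
- //       {{2, 5}, {0, 1}},
- //       torch::dtype(torch::kInt64).device(key_caches[0].device()));
- //   copy_blocks(key_caches, value_caches, mapping);  // copy 2->5 and 0->1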
- namespace aphrodite {
- template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
- __global__ void reshape_and_cache_kernel(
- const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
- const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
- cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
- // block_size, x]
- cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
- // block_size]
- const int64_t* __restrict__ slot_mapping, // [num_tokens]
- const int key_stride, const int value_stride, const int num_heads,
- const int head_size, const int block_size, const int x, const float k_scale,
- const float v_scale) {
- const int64_t token_idx = blockIdx.x;
- const int64_t slot_idx = slot_mapping[token_idx];
- if (slot_idx < 0) {
- // Padding token that should be ignored.
- return;
- }
- const int64_t block_idx = slot_idx / block_size;
- const int64_t block_offset = slot_idx % block_size;
- const int n = num_heads * head_size;
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- const int64_t src_key_idx = token_idx * key_stride + i;
- const int64_t src_value_idx = token_idx * value_stride + i;
- const int head_idx = i / head_size;
- const int head_offset = i % head_size;
- const int x_idx = head_offset / x;
- const int x_offset = head_offset % x;
- const int64_t tgt_key_idx =
- block_idx * num_heads * (head_size / x) * block_size * x +
- head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
- block_offset * x + x_offset;
- const int64_t tgt_value_idx =
- block_idx * num_heads * head_size * block_size +
- head_idx * head_size * block_size + head_offset * block_size +
- block_offset;
- scalar_t tgt_key = key[src_key_idx];
- scalar_t tgt_value = value[src_value_idx];
- if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
- key_cache[tgt_key_idx] = tgt_key;
- value_cache[tgt_value_idx] = tgt_value;
- } else {
- key_cache[tgt_key_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
- value_cache[tgt_value_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
- }
- }
- }
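- // Worked example of the key-cache index arithmetic above, with illustrative
- // sizes num_heads = 8, head_size = 128, block_size = 16, x = 8 (so
- // head_size / x = 16). For slot_idx = 35: block_idx = 35 / 16 = 2 and
- // block_offset = 35 % 16 = 3. For the element with head_idx = 1 and
- // head_offset = 20: x_idx = 20 / 8 = 2, x_offset = 20 % 8 = 4, giving
- //   tgt_key_idx = 2*8*16*16*8 + 1*16*16*8 + 2*16*8 + 3*8 + 4
- //               = 32768 + 2048 + 256 + 24 + 4 = 35100.
- // The x-split keeps x consecutive key elements of a head contiguous so the
- // attention kernel can fetch them with a single vectorized load.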
- template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
- __global__ void reshape_and_cache_flash_kernel(
- const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
- const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
- cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
- // head_size]
- cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
- // head_size]
- const int64_t* __restrict__ slot_mapping, // [num_tokens]
- const int block_stride, const int key_stride, const int value_stride,
- const int num_heads, const int head_size, const int block_size,
- const float k_scale, const float v_scale) {
- const int64_t token_idx = blockIdx.x;
- const int64_t slot_idx = slot_mapping[token_idx];
- // NOTE: slot_idx can be -1 if the token is padded
- if (slot_idx < 0) {
- return;
- }
- const int64_t block_idx = slot_idx / block_size;
- const int64_t block_offset = slot_idx % block_size;
- const int n = num_heads * head_size;
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
- const int64_t src_key_idx = token_idx * key_stride + i;
- const int64_t src_value_idx = token_idx * value_stride + i;
- const int head_idx = i / head_size;
- const int head_offset = i % head_size;
- const int64_t tgt_key_value_idx = block_idx * block_stride +
- block_offset * num_heads * head_size +
- head_idx * head_size + head_offset;
- scalar_t tgt_key = key[src_key_idx];
- scalar_t tgt_value = value[src_value_idx];
- if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
- key_cache[tgt_key_value_idx] = tgt_key;
- value_cache[tgt_key_value_idx] = tgt_value;
- } else {
- key_cache[tgt_key_value_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
- value_cache[tgt_key_value_idx] =
- fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
- }
- }
- }
- } // namespace aphrodite
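- // In the "flash" layout both caches share the shape
- // [num_blocks, block_size, num_heads, head_size], which is why a single
- // tgt_key_value_idx addresses key_cache and value_cache alike. With the
- // same illustrative sizes as above (block_size = 16, num_heads = 8,
- // head_size = 128), slot_idx = 35 lands in block 2 at row 3, and the element
- // (head_idx = 1, head_offset = 20) sits at
- //   2 * block_stride + 3*8*128 + 1*128 + 20 = 2 * block_stride + 3220.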
- // KV_T is the data type of the key and value tensors.
- // CACHE_T is the stored data type of the kv-cache.
- // KV_DTYPE is the real data type of the kv-cache.
- #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
- aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
- <<<grid, block, 0, stream>>>( \
- reinterpret_cast<KV_T*>(key.data_ptr()), \
- reinterpret_cast<KV_T*>(value.data_ptr()), \
- reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
- reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
- slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
- num_heads, head_size, block_size, x, k_scale, v_scale);
- void reshape_and_cache(
- torch::Tensor& key, // [num_tokens, num_heads, head_size]
- torch::Tensor& value, // [num_tokens, num_heads, head_size]
- torch::Tensor&
- key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
- torch::Tensor&
- value_cache, // [num_blocks, num_heads, head_size, block_size]
- torch::Tensor& slot_mapping, // [num_tokens]
- const std::string& kv_cache_dtype, const double k_scale,
- const double v_scale) {
- int num_tokens = key.size(0);
- int num_heads = key.size(1);
- int head_size = key.size(2);
- int block_size = key_cache.size(3);
- int x = key_cache.size(4);
- int key_stride = key.stride(0);
- int value_stride = value.stride(0);
- dim3 grid(num_tokens);
- dim3 block(std::min(num_heads * head_size, 512));
- const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
- CALL_RESHAPE_AND_CACHE);
- }
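- // Illustrative call (a sketch; tensor shapes follow the comments above).
- // "auto" stores keys and values in their original dtype, while e.g.
- // "fp8_e4m3" quantizes them on the fly using k_scale / v_scale:
- //
- //   reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
- //                     "auto", /*k_scale=*/1.0, /*v_scale=*/1.0);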
- // KV_T is the data type of the key and value tensors.
- // CACHE_T is the stored data type of the kv-cache.
- // KV_DTYPE is the real data type of the kv-cache.
- #define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
- aphrodite::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
- <<<grid, block, 0, stream>>>( \
- reinterpret_cast<KV_T*>(key.data_ptr()), \
- reinterpret_cast<KV_T*>(value.data_ptr()), \
- reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
- reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
- slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
- value_stride, num_heads, head_size, block_size, k_scale, v_scale);
- void reshape_and_cache_flash(
- torch::Tensor& key, // [num_tokens, num_heads, head_size]
- torch::Tensor& value, // [num_tokens, num_heads, head_size]
- torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
- torch::Tensor&
- value_cache, // [num_blocks, block_size, num_heads, head_size]
- torch::Tensor& slot_mapping, // [num_tokens]
- const std::string& kv_cache_dtype, const double k_scale,
- const double v_scale) {
- int num_tokens = key.size(0);
- int num_heads = key.size(1);
- int head_size = key.size(2);
- int block_size = key_cache.size(1);
- int key_stride = key.stride(0);
- int value_stride = value.stride(0);
- int block_stride = key_cache.stride(0);
- TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0),
- "key_cache and value_cache must have the same block stride");
- dim3 grid(num_tokens);
- dim3 block(std::min(num_heads * head_size, 512));
- const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
- CALL_RESHAPE_AND_CACHE_FLASH);
- }
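- // As with reshape_and_cache, one thread block is launched per token; the
- // std::min(num_heads * head_size, 512) cap only bounds the block size, since
- // the per-thread loop inside the kernel strides over any remaining elements.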
- namespace aphrodite {
- template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
- __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
- Tout* __restrict__ dst_cache,
- const float scale,
- const int64_t block_stride) {
- const int64_t block_idx = blockIdx.x;
- for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
- int64_t idx = block_idx * block_stride + i;
- dst_cache[idx] =
- fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
- }
- }
- } // namespace aphrodite
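- // Each CUDA block converts one cache block: its threads stride over the
- // block_stride contiguous elements, and fp8::scaled_convert<Tout, Tin, kv_dt>
- // covers both directions -- quantization when Tout is the byte type and
- // dequantization when Tin is.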
- #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
- aphrodite::convert_fp8_kernel<Tout, Tin, KV_DTYPE> \
- <<<grid, block, 0, stream>>>( \
- reinterpret_cast<Tin*>(src_cache.data_ptr()), \
- reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
- // Only for testing: standalone conversion between an fp8 cache and a
- // higher-precision cache.
- void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
- const double scale, const std::string& kv_cache_dtype) {
- torch::Device src_device = src_cache.device();
- torch::Device dst_device = dst_cache.device();
- TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
- TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
- TORCH_CHECK(src_device.index() == dst_device.index(),
- "src and dst must be on the same GPU");
- at::cuda::OptionalCUDAGuard device_guard(src_device);
- int64_t num_blocks = src_cache.size(0);
- int64_t block_stride = src_cache.stride(0);
- dim3 grid(num_blocks);
- dim3 block(std::min(block_stride, int64_t(512)));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- if (kv_cache_dtype == "auto") {
- if (src_cache.dtype() == at::ScalarType::Float) {
- CALL_CONVERT_FP8(uint8_t, float, aphrodite::Fp8KVCacheDataType::kAuto);
- } else if (src_cache.dtype() == at::ScalarType::Half) {
- CALL_CONVERT_FP8(uint8_t, uint16_t, aphrodite::Fp8KVCacheDataType::kAuto);
- } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
- CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
- aphrodite::Fp8KVCacheDataType::kAuto);
- } else if (dst_cache.dtype() == at::ScalarType::Float) {
- CALL_CONVERT_FP8(float, uint8_t, aphrodite::Fp8KVCacheDataType::kAuto);
- } else if (dst_cache.dtype() == at::ScalarType::Half) {
- CALL_CONVERT_FP8(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kAuto);
- } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
- CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
- aphrodite::Fp8KVCacheDataType::kAuto);
- }
- } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
- if (src_cache.dtype() == at::ScalarType::Float) {
- CALL_CONVERT_FP8(uint8_t, float, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
- } else if (src_cache.dtype() == at::ScalarType::Half) {
- CALL_CONVERT_FP8(uint8_t, uint16_t,
- aphrodite::Fp8KVCacheDataType::kFp8E4M3);
- } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
- CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
- aphrodite::Fp8KVCacheDataType::kFp8E4M3);
- } else if (dst_cache.dtype() == at::ScalarType::Float) {
- CALL_CONVERT_FP8(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);
- } else if (dst_cache.dtype() == at::ScalarType::Half) {
- CALL_CONVERT_FP8(uint16_t, uint8_t,
- aphrodite::Fp8KVCacheDataType::kFp8E4M3);
- } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
- CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
- aphrodite::Fp8KVCacheDataType::kFp8E4M3);
- }
- } else {
- TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
- }
- }
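- // Illustrative round trip for testing (a sketch; `cache` is assumed to be a
- // CUDA half-precision cache tensor, and `fp8_cache` / `recovered` are
- // preallocated uint8 and half tensors of the same shape):
- //
- //   convert_fp8(fp8_cache, cache, /*scale=*/1.0, "fp8_e4m3");      // quantize
- //   convert_fp8(recovered, fp8_cache, /*scale=*/1.0, "fp8_e4m3");  // dequantize
- //
- // Note that Half dispatches through uint16_t above because the fp8 conversion
- // utilities operate on the raw bit pattern.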