// NOTE(review): the angle-bracket include targets and template argument lists
// of this file were lost in extraction. The ATen/torch headers below were
// reconstructed from the names this file uses (at::cuda::getCurrentCUDAStream,
// torch::Tensor, TORCH_CHECK) -- confirm against the build.
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "../../dispatch_utils.h"
#include "../../reduction.cuh"

// Converts a float to int8: rounds to an integer and saturates to the int8
// range.
//   - CUDA: a single `cvt.rni.sat.s8.f32` instruction.
//   - ROCm: std::nearbyint followed by std::clamp to
//     [numeric_limits<int8_t>::min(), numeric_limits<int8_t>::max()].
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  // The converted value is read back from the 32-bit register as an int8.
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace aphrodite {

// Quantizes `input` to int8 with one precomputed scale shared by all rows:
//   out[row, i] = float_to_int8_rn(float(input[row, i]) / *scale_ptr)
//
// Launch layout: blockIdx.x selects the row (token); the threads of the block
// step through the `hidden_size` elements of that row in increments of
// blockDim.x. Rows are addressed as `row * hidden_size`, i.e. densely packed.
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  // 64-bit row offset: `blockIdx.x * hidden_size` overflows a 32-bit int once
  // the tensor holds more than 2^31 elements.
  int64_t const row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;
  scalar_t const* const row_in = input + row_offset;
  int8_t* const row_out = out + row_offset;

  scale_type const scale = *scale_ptr;

  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    row_out[i] = float_to_int8_rn(static_cast<float>(row_in[i]) / scale);
  }
}

// Quantizes `input` to int8 with a per-row scale computed on the fly:
//   absmax      = max_i |input[row, i]|
//   scale[row]  = absmax / 127
//   out[row, i] = float_to_int8_rn(float(input[row, i]) * (127 / absmax))
// A row whose absmax is zero gets scale 0 and all-zero output.
//
// Launch layout: blockIdx.x selects the row (token); the threads of the block
// step through the `hidden_size` elements of that row in increments of
// blockDim.x. Uses one statically allocated __shared__ float per block.
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  // 64-bit row offset: `token_idx * hidden_size` overflows a 32-bit int once
  // the tensor holds more than 2^31 elements.
  int64_t const row_offset = static_cast<int64_t>(token_idx) * hidden_size;
  scalar_t const* const row_in = input + row_offset;
  int8_t* const row_out = out + row_offset;

  // Pass 1: per-thread running maximum of |x| over this thread's elements.
  // The `>` comparison leaves absmax_val untouched when val is NaN.
  float absmax_val = 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float const val = fabsf(static_cast<float>(row_in[i]));
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  // Combine the per-thread maxima across the block. Only thread 0's result is
  // consumed here; it is broadcast to the other threads via shared memory.
  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  // Guard the reciprocal: for an all-zero row 127/0 is +inf and 0*inf is NaN,
  // which the ROCm path of float_to_int8_rn would cast to an integer (undefined
  // behaviour). A zero multiplier yields the same all-zero output that the
  // saturating PTX conversion produces for NaN.
  float const tmp_scale =
      block_absmax_val > 0.0f ? 127.0f / block_absmax_val : 0.0f;

  // Pass 2: scale and convert.
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    row_out[i] = float_to_int8_rn(static_cast<float>(row_in[i]) * tmp_scale);
  }
}

}  // namespace aphrodite

// Quantizes `input` into `out` (int8) using a single precomputed float scale.
//
//   out   : [..., hidden_size] contiguous, same number of elements as `input`
//   input : [..., hidden_size] contiguous, floating-point dtype
//   scale : exactly one element
//
// The kernel is enqueued on the current CUDA stream with one block per row of
// `hidden_size` elements and min(hidden_size, 1024) threads per block. An
// empty `input` is a no-op.
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(out.numel() == input.numel(),
              "out and input must have the same number of elements");

  // Nothing to quantize. Returning here also avoids dividing by a zero
  // hidden_size and launching a kernel with an empty grid or block.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        aphrodite::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scale.data_ptr<float>(), hidden_size);
      });
}

// Quantizes `input` into `out` (int8) and writes one float scale per row of
// `hidden_size` elements into `scales`.
//
//   out    : [..., hidden_size] contiguous, same number of elements as `input`
//   input  : [..., hidden_size] contiguous, floating-point dtype
//   scales : contiguous, at least one element per row of `input`
//
// The kernel is enqueued on the current CUDA stream with one block per row and
// min(hidden_size, 1024) threads per block. An empty `input` is a no-op.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  // The kernel indexes `scales` linearly by row, so it must be dense.
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(out.numel() == input.numel(),
              "out and input must have the same number of elements");

  // Nothing to quantize. Returning here also avoids dividing by a zero
  // hidden_size and launching a kernel with an empty grid or block.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  // Every row writes scales[row]; reject buffers that are too small.
  TORCH_CHECK(scales.numel() >= num_tokens,
              "scales must hold one element per row of input");

  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        aphrodite::dynamic_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scales.data_ptr<float>(), hidden_size);
      });
}