123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- #include <ATen/cuda/CUDAContext.h>
- #include <torch/all.h>
- #include <cmath>
- #include "../../dispatch_utils.h"
- #include "../../reduction.cuh"
// Converts a float to int8 by rounding to the nearest integer and saturating
// to the int8 range [-128, 127]. NaN inputs produce 0 on both back ends.
//
// Device-only helper; called once per element by the quantization kernels.
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  // ROCm path: no PTX available, so round and saturate by hand.
  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  // The saturating PTX conversion on the CUDA path turns NaN into 0. Mirror
  // that here: NaN would pass through clamp unchanged and casting NaN to an
  // integer type is undefined behaviour.
  if (std::isnan(x)) {
    return 0;
  }

  // round (nearbyint honours the current rounding mode, assumed to be the
  // default round-to-nearest-even so that it matches `cvt.rni`)
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path: a single instruction rounds to nearest-even (.rni) and
  // saturates (.sat) to the signed 8-bit range.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  // The result lives in the low byte of the 32-bit register.
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
namespace aphrodite {

// Quantizes `input` to int8 with one caller-supplied scale shared by all
// tokens: out[t, i] = round_sat(input[t, i] / scale).
//
// Launch layout: one thread block per token (blockIdx.x selects the row);
// the threads of a block stride over the hidden dimension, so any
// blockDim.x >= 1 is valid. No shared memory is used.
//
//   input        row-major [num_tokens, hidden_size]
//   out          row-major [num_tokens, hidden_size]
//   scale_ptr    pointer to the single scale value
//   hidden_size  number of elements per token
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  // 64-bit row offset: token index times hidden_size exceeds INT_MAX for
  // large tensors and would wrap around if computed in 32-bit int.
  int64_t const row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;
  scale_type const scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[row_offset + i] =
        float_to_int8_rn(static_cast<float>(input[row_offset + i]) / scale);
  }
}

// Quantizes `input` to int8 with a per-token scale computed on the fly:
//   scale[t]  = max_i |input[t, i]| / 127
//   out[t, i] = round_sat(input[t, i] * 127 / max_i |input[t, i]|)
//
// Launch layout: one thread block per token (blockIdx.x selects the row);
// the threads of a block stride over the hidden dimension. Uses one float of
// static shared memory to broadcast the row maximum.
//
//   input        row-major [num_tokens, hidden_size]
//   out          row-major [num_tokens, hidden_size]
//   scale        one entry per token, written by this kernel
//   hidden_size  number of elements per token
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  // 64-bit indices so that the row offset cannot overflow for large tensors.
  int64_t const token_idx = blockIdx.x;
  int64_t const row_offset = token_idx * hidden_size;

  // Pass 1: each thread tracks the largest magnitude among its elements.
  float absmax_val = 0.0f;
  float const zero = 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[row_offset + i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  // Combine the per-thread maxima. Only thread 0's value is consumed below;
  // presumably blockReduceMax (declared outside this file) guarantees the
  // block-wide result only there - confirm against reduction.cuh.
  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  // Make the shared maximum visible to every thread before it is read.
  __syncthreads();

  // An all-zero row has absmax 0; dividing by it would give inf and the
  // products below would become NaN. Use 0 instead so such rows quantize to
  // 0 through well-defined arithmetic.
  float const tmp_scale =
      block_absmax_val > 0.0f ? 127.0f / block_absmax_val : 0.0f;

  // Pass 2: scale each element into the int8 range and convert.
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[row_offset + i] = float_to_int8_rn(
        static_cast<float>(input[row_offset + i]) * tmp_scale);
  }
}

}  // namespace aphrodite
// Quantizes `input` to int8 into `out` using a single precomputed scale.
//
//   out    int8 tensor, contiguous, same number of elements as `input`
//   input  floating-point tensor, contiguous; the last dimension is the
//          hidden size, all leading dimensions are flattened into tokens
//   scale  float tensor holding exactly one element
//
// The kernel is enqueued on the current CUDA stream with one block per token
// and up to 1024 threads per block. Raises (via TORCH_CHECK / the launch
// check) on invalid arguments or a failed launch.
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  // The kernel writes one int8 per input element; a smaller `out` would be
  // written out of bounds.
  TORCH_CHECK(out.numel() == input.numel(),
              "out and input must have the same number of elements");

  // Nothing to do for empty input. This also avoids dividing by a zero
  // hidden size and launching a zero-sized grid, which is an invalid launch.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        aphrodite::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scale.data_ptr<float>(), hidden_size);
      });
  // Kernel launches do not report errors by themselves; surface a bad launch
  // configuration here instead of at some unrelated later call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Quantizes `input` to int8 into `out`, computing one scale per token and
// storing it in `scales`.
//
//   out     int8 tensor, contiguous, same number of elements as `input`
//   input   floating-point tensor, contiguous; the last dimension is the
//           hidden size, all leading dimensions are flattened into tokens
//   scales  float tensor, contiguous, with at least one element per token;
//           entry t receives absmax(token t) / 127
//
// The kernel is enqueued on the current CUDA stream with one block per token
// and up to 1024 threads per block. Raises (via TORCH_CHECK / the launch
// check) on invalid arguments or a failed launch.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  // The kernel indexes `scales` linearly by token, which is only correct for
  // a contiguous tensor.
  TORCH_CHECK(scales.is_contiguous());
  // The kernel writes one int8 per input element; a smaller `out` would be
  // written out of bounds.
  TORCH_CHECK(out.numel() == input.numel(),
              "out and input must have the same number of elements");

  // Nothing to do for empty input. This also avoids dividing by a zero
  // hidden size and launching a zero-sized grid, which is an invalid launch.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  // Every block stores scale[token_idx]; make sure each store is in bounds.
  TORCH_CHECK(scales.numel() >= num_tokens,
              "scales must provide one entry per token");

  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        aphrodite::dynamic_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scales.data_ptr<float>(), hidden_size);
      });
  // Kernel launches do not report errors by themselves; surface a bad launch
  // configuration here instead of at some unrelated later call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|