#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <algorithm>  // std::clamp, std::min
#include <cmath>      // std::nearbyint
#include <limits>     // std::numeric_limits

#include "../../dispatch_utils.h"
// Convert a float to int8 with round-to-nearest and saturation.
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // Round to the nearest integer (ties to even under the default
  // rounding mode), then saturate to the int8 range.
  float dst = std::nearbyint(x);
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path: a single PTX instruction performs the round-to-nearest-even
  // ("rni") and saturating ("sat") float -> s8 conversion.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
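
// Sanity examples for float_to_int8_rn (assuming the default to-nearest
// rounding mode on the ROCm path; the CUDA path always rounds ties to even):
//   float_to_int8_rn(1.5f)   -> 2     (ties go to the even integer)
//   float_to_int8_rn(2.5f)   -> 2
//   float_to_int8_rn(300.f)  -> 127   (saturates at int8 max)
//   float_to_int8_rn(-300.f) -> -128  (saturates at int8 min)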

namespace aphrodite {

// One CUDA block per token; threads within the block stride across the
// hidden dimension. Each element is divided by the single static scale and
// converted to int8.
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
    const scale_type* scale_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  // Use a 64-bit token index so token_idx * hidden_size cannot overflow
  // 32-bit arithmetic on large inputs.
  const int64_t token_idx = blockIdx.x;
  const scale_type scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
  }
}

}  // namespace aphrodite
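
// Host-side entry point. Applies static (per-tensor) int8 quantization:
// out[i] = saturate(round(input[i] / scale)). Expects contiguous tensors
// and a single-element fp32 scale.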
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale)  // [1]
{
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);

  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;

  // One block per token; cap the block size at the 1024-thread limit.
  const dim3 grid(num_tokens);
  const dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        aphrodite::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scale.data_ptr<float>(), hidden_size);
      });
}
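
// A minimal usage sketch from the C++ side (the shapes and the scale choice
// are illustrative assumptions, not part of this file):
//
//   torch::Tensor x = torch::randn({num_tokens, hidden_size},
//                                  torch::device(torch::kCUDA));
//   torch::Tensor scale = (x.abs().max() / 127.0).reshape({1});  // fp32, [1]
//   torch::Tensor out = torch::empty_like(x, torch::kInt8);
//   static_scaled_int8_quant(out, x, scale);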