int8_quant_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include <algorithm>  // std::clamp
#include <cmath>      // std::nearbyint
#include <limits>     // std::numeric_limits

#include "../../dispatch_utils.h"
// Convert a float to int8 with round-to-nearest-even and saturation.
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // Round to the nearest integer (ties to even under the default
  // floating-point rounding mode).
  float dst = std::nearbyint(x);
  // Saturate to the int8 range.
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path: cvt.rni.sat.s8.f32 rounds to the nearest integer (ties to
  // even) and saturates to [-128, 127] in a single PTX instruction.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
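// Examples (both paths): float_to_int8_rn(2.5f) == 2 (ties round to even),
// float_to_int8_rn(300.f) == 127 and float_to_int8_rn(-300.f) == -128
// (saturated).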
namespace aphrodite {

// One thread block per token; threads stride across the hidden dimension.
// Each element is divided by a single static (per-tensor) scale and rounded
// to int8, i.e. q = saturate(round(x / scale)).
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
    scale_type scale, const int hidden_size) {
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] =
        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
  }
}

}  // namespace aphrodite
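// Note: "static" here means the scale is precomputed and passed in, as
// opposed to dynamic schemes that derive a per-token scale from each row's
// absolute maximum at runtime.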
void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
                              torch::Tensor& input,  // [..., hidden_size]
                              float scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;

  // One block per token, capped at the 1024-thread-per-block limit.
  const dim3 grid(num_tokens);
  const dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        aphrodite::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(), scale,
                                         hidden_size);
      });
}
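// Usage sketch (hypothetical; the Python-facing op registration lives
// elsewhere in the repo). Assumes `out` is preallocated as int8 with the
// same shape as `input`:
//
//   torch::Tensor input = torch::randn(
//       {num_tokens, hidden_size},
//       torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   torch::Tensor out = torch::empty_like(
//       input, input.options().dtype(torch::kInt8));
//   static_scaled_int8_quant(out, input, /*scale=*/0.05f);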