// int8_quant_kernels.cu (4.1 KB)
  1. #include <ATen/cuda/CUDAContext.h>
  2. #include <torch/all.h>
  3. #include <cmath>
  4. #include "../../dispatch_utils.h"
  5. #include "../../reduction.cuh"
  // Converts a float to int8, rounding to nearest and saturating to the int8
// range [-128, 127].
//
// The CUDA path uses one PTX instruction (round-to-nearest-even, saturate).
// The ROCm path emulates it in software. There, NaN is mapped to 0
// explicitly, because std::clamp passes NaN through unchanged and casting a
// NaN float to an integer type is undefined behavior in C++.
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // NaN has no integer value; return a well-defined 0 instead of hitting UB
  // in the final cast.
  if (std::isnan(x)) {
    return 0;
  }
  // Round using the current rounding mode (ties-to-even, assuming the
  // default FE_TONEAREST mode is in effect).
  float dst = std::nearbyint(x);
  // Saturate to the representable int8 range.
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path: .rni = round to nearest integer (ties to even),
  // .sat = saturate to the signed 8-bit range.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  // The int8 result lives in the low byte of the 32-bit register
  // (little-endian), which is what this reinterpretation reads.
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
  namespace aphrodite {

// Static (per-tensor) int8 quantization: out = round_sat(input / *scale_ptr).
//
// Launch layout: one block per token (blockIdx.x is the token index). The
// block's threads loop over the hidden dimension with stride blockDim.x, so
// any block size is valid. `scale_ptr` points to a single scale shared by
// every token.
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  // 64-bit token index so the row offset token_idx * hidden_size cannot
  // overflow when num_tokens * hidden_size exceeds INT_MAX.
  int64_t const token_idx = blockIdx.x;
  scalar_t const* __restrict__ token_input = input + token_idx * hidden_size;
  int8_t* __restrict__ token_out = out + token_idx * hidden_size;

  scale_type const scale = *scale_ptr;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    token_out[i] =
        float_to_int8_rn(static_cast<float>(token_input[i]) / scale);
  }
}

// Dynamic (per-token) int8 quantization.
//
// Launch layout: one block per token (blockIdx.x is the token index). Each
// block:
//   1. computes the absolute maximum of its row (blockReduceMax),
//   2. stores scale[token_idx] = absmax / 127,
//   3. writes out = round_sat(input * 127 / absmax), so the largest-magnitude
//      element maps to +/-127.
// All threads of the block must reach blockReduceMax and __syncthreads().
// NOTE(review): any block-size requirement comes from blockReduceMax in
// reduction.cuh; confirm there.
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  // 64-bit token index: see static_scaled_int8_quant_kernel.
  int64_t const token_idx = blockIdx.x;
  scalar_t const* __restrict__ token_input = input + token_idx * hidden_size;
  int8_t* __restrict__ token_out = out + token_idx * hidden_size;

  // Per-thread absolute maximum over the elements this thread visits.
  float absmax_val = 0.0f;
  float const zero = 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(token_input[i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  // Presumably only thread 0 holds the final reduced value (hence "_maybe"),
  // so it is published to the whole block through shared memory.
  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  // An all-zero row has absmax 0. Then 127 / 0 = inf, and 0 * inf = NaN would
  // be fed to the conversion. A zero multiplier yields the correct all-zero
  // output without relying on NaN handling in float_to_int8_rn.
  float const tmp_scale =
      block_absmax_val > 0.0f ? 127.0f / block_absmax_val : 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    token_out[i] =
        float_to_int8_rn(static_cast<float>(token_input[i]) * tmp_scale);
  }
}

}  // namespace aphrodite
  // Quantizes `input` to int8 with a single per-tensor scale:
// out = round_sat(input / scale), saturated to [-128, 127].
//
// `input` and `out` must be contiguous and are laid out as
// [..., hidden_size]. The input dtype is dispatched via
// APHRODITE_DISPATCH_FLOATING_TYPES, and `out` is read as int8. `scale` must
// hold exactly one float32 element. The kernel is launched asynchronously on
// the current CUDA stream, with one block per token and up to 1024 threads
// per block.
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);

  // Nothing to quantize. Returning early also avoids two failures: an
  // integer division by zero below when hidden_size == 0, and an invalid
  // zero-block kernel launch when there are no tokens.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        aphrodite::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scale.data_ptr<float>(), hidden_size);
      });
}
  // Quantizes `input` to int8 with one dynamically computed scale per token.
//
// For each token row, the kernel writes scales[token] = absmax(row) / 127 and
// out[row] = round_sat(row * 127 / absmax(row)). `input` and `out` must be
// contiguous and are laid out as [..., hidden_size]. The input dtype is
// dispatched via APHRODITE_DISPATCH_FLOATING_TYPES, and `out` is read as
// int8. `scales` is read as float32 and must hold at least one entry per
// token. The kernel is launched asynchronously on the current CUDA stream,
// with one block per token and up to 1024 threads per block.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  // Nothing to quantize. Returning early also avoids two failures: an
  // integer division by zero below when hidden_size == 0, and an invalid
  // zero-block kernel launch when there are no tokens.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  // The kernel writes scales[token_idx] for every token; a smaller buffer
  // would be written out of bounds.
  TORCH_CHECK(scales.numel() >= num_tokens,
              "scales must have at least one entry per token");

  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        aphrodite::dynamic_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scales.data_ptr<float>(), hidden_size);
      });
}