int8_quant_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <algorithm>  // std::clamp, std::min
#include <cmath>      // std::nearbyint (ROCm path)
#include <limits>     // std::numeric_limits (ROCm path)

#include "../../dispatch_utils.h"

#ifndef USE_ROCM
  #include <cub/util_type.cuh>
  #include <cub/cub.cuh>
#else
  #include <hipcub/util_type.hpp>
  #include <hipcub/hipcub.hpp>
#endif
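
// Per-token int8 quantization kernels. Each CUDA block handles one token
// (one row of hidden_size elements): the static variant divides by a single
// precomputed scale, while the dynamic variant first computes a per-token
// scale from the row's absolute maximum and writes it out alongside the
// quantized values.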
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // Round to the nearest integer, then saturate to the int8 range.
  float dst = std::nearbyint(x);
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path: cvt.rni.sat.s8.f32 rounds to the nearest integer and
  // saturates to [-128, 127] in a single instruction.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace aphrodite {
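
// Quantize one token per block with a single, precomputed scale:
// out[i] = round_sat_int8(input[i] / scale).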
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
  }
}
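
// Quantize one token per block with a scale computed on the fly from the
// row's absolute maximum; the scale is also written out so callers can
// dequantize later.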
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  float absmax_val = 0.0f;
  float const zero = 0.0f;

  // Each thread finds the absolute maximum over its strided slice of the row.
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[token_idx * hidden_size + i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  // Reduce the per-thread maxima to a single per-token absmax, then have
  // thread 0 publish it via shared memory and write the per-token scale.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  float const block_absmax_val_maybe =
      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  // Quantize with the reciprocal so the inner loop multiplies rather than
  // divides: q = round_sat_int8(x * 127 / absmax).
  float const tmp_scale = 127.0f / block_absmax_val;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
  }
}

}  // namespace aphrodite
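
// Host-side launchers: one block per token row, up to 1024 threads per
// block, launched on the current CUDA stream.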
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        aphrodite::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scale.data_ptr<float>(), hidden_size);
      });
}
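
// Dynamic variant: additionally fills `scales` with one float per token
// (absmax / 127), so it must hold at least num_tokens elements.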
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        aphrodite::dynamic_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scales.data_ptr<float>(), hidden_size);
      });
}
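
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this file). It assumes these
// launchers are visible to the caller, e.g. via an ops header declared
// elsewhere in the repo; tensor shapes follow the comments above.
//
//   int64_t num_tokens = 16, hidden_size = 4096;
//   auto input = torch::randn({num_tokens, hidden_size},
//                             torch::dtype(torch::kFloat16)
//                                 .device(torch::kCUDA))
//                    .contiguous();
//   auto out = torch::empty_like(input, input.options().dtype(torch::kInt8));
//   auto scales = torch::empty({num_tokens, 1},
//                              torch::dtype(torch::kFloat32)
//                                  .device(torch::kCUDA));
//   dynamic_scaled_int8_quant(out, input, scales);
//   // Dequantize to verify: out.to(torch::kFloat32) * scales ~= input.
// ---------------------------------------------------------------------------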