// int8_quant_kernels.cu
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <cmath>

#include "../../dispatch_utils.h"

#ifndef USE_ROCM
  #include <cub/util_type.cuh>
  #include <cub/cub.cuh>
#else
  #include <hipcub/util_type.hpp>
  #include <hipcub/hipcub.hpp>
#endif

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
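
// Illustrative behavior of float_to_int8_rn (both paths round to nearest,
// ties to even, and saturate): 1.5f -> 2, 2.5f -> 2, 127.7f -> 127,
// -300.0f -> -128.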

static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  // int32_max is not exactly representable as float.
  // Therefore, we need to be careful and manually return int32_max on overflow.
  // For symmetry, we also do the same for int32_min, even though it is exactly
  // representable as float and the conversion should be exact.
  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
  static constexpr auto i32_min_f = static_cast<float>(i32_min);
  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
  static constexpr auto i32_max_f = static_cast<float>(i32_max);

  // To match the rounding mode of CUDA, we use nearbyint.
  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
  // If that changes in the future, we may need to set the rounding mode
  // explicitly, either at runtime or compile time.
  float dst = std::nearbyint(x);

  // saturate on the higher end.
  if (dst >= i32_max_f) {
    return i32_max;
  }
  // saturate on the lower end.
  if (dst <= i32_min_f) {
    return i32_min;
  }

  return static_cast<int32_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int32_t&>(dst);
#endif
}
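
// Illustrative behavior of float_to_int32_rn: 3e9f saturates to 2147483647
// (INT32_MAX) on both paths. The explicit >= / <= checks on the ROCm path are
// needed because static_cast<float>(INT32_MAX) rounds up to 2147483648.0f.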

static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<int32_t>(std::numeric_limits<int8_t>::max());

  // saturate
  int32_t dst = std::clamp(x, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
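
// Illustrative behavior of int32_to_int8: 300 -> 127, -300 -> -128, 42 -> 42.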

namespace aphrodite {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
  }
}
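
// static_scaled_int8_quant_kernel computes, per element,
//   q = saturate_int8(round(x / scale))
// with one block per token and a strided loop over hidden_size.
// Illustrative: scale = 0.1f, x = 2.34f -> round(23.4) = 23.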

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, azp_type const* azp_ptr,
    const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;
  azp_type const azp = *azp_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
    out[token_idx * hidden_size + i] = quant_val;
  }
}
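
// static_scaled_int8_azp_quant_kernel computes, per element,
//   q = saturate_int8(round_to_int32(x / scale) + azp)
// Illustrative: scale = 0.1f, azp = 10, x = 2.34f -> 23 + 10 = 33.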

template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  float absmax_val = 0.0f;
  float const zero = 0.0f;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[token_idx * hidden_size + i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  float const block_absmax_val_maybe =
      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  float const tmp_scale = 127.0f / block_absmax_val;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
  }
}
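
// dynamic_scaled_int8_quant_kernel derives a per-token symmetric scale from
// the block-wide absolute maximum: scale[token] = max|x| / 127, and quantizes
// with q = saturate_int8(round(x * 127 / max|x|)).
// Illustrative: a token with max|x| = 6.35 gets scale 0.05; x = 1.0 -> 20.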

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, azp_type* azp, const int hidden_size) {
  int const token_idx = blockIdx.x;

  // Scan for the min and max value for this token
  float max_val = std::numeric_limits<float>::min();
  float min_val = std::numeric_limits<float>::max();
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
    max_val = std::max(max_val, val);
    min_val = std::min(min_val, val);
  }

  // Reduce the max and min values across the block
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStorage;
  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
  // The min reduction reuses the same shared-memory temp storage as the max
  // reduction, so wait for the max reduction to complete first.
  __syncthreads();
  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

  __shared__ scale_type scale_sh;
  __shared__ azp_type azp_sh;

  // Compute the scale and zero point and store them, only on the first thread
  if (threadIdx.x == 0) {
    float const scale_val = (max_val - min_val) / 255.0f;
    // Use rounding to even (same as torch.round)
    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
    auto const azp_val = static_cast<azp_type>(azp_float);

    // Store the scale and azp into shared and global
    scale[token_idx] = scale_sh = scale_val;
    azp[token_idx] = azp_sh = azp_val;
  }

  // Wait for the scale and azp to be computed
  __syncthreads();

  float const scale_val = scale_sh;
  azp_type const azp_val = azp_sh;

  // Quantize the values
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
    auto const quant_val =
        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
    out[token_idx * hidden_size + i] = quant_val;
  }
}
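
// dynamic_scaled_int8_azp_quant_kernel uses an asymmetric (affine) mapping:
//   scale = (max - min) / 255,  azp = round(-128 - min / scale),
//   q = saturate_int8(round_to_int32(x / scale) + azp)
// Illustrative: min = -1.0, max = 4.1 -> scale = 0.02, azp = -78, so
// x = -1.0 maps to -128, x = 0.0 maps to -78, and x = 4.1 maps to 127.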

}  // namespace aphrodite

void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale,
                              c10::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp || azp->numel() == 1);

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          aphrodite::static_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), hidden_size);
        } else {
          aphrodite::static_scaled_int8_azp_quant_kernel<scalar_t, float,
                                                         int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}
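
// Host-side usage sketch for static_scaled_int8_quant (illustrative only; the
// shapes and dtypes shown are assumptions, not enforced beyond the checks
// above):
//   torch::Tensor input = ...;  // [num_tokens, hidden_size], floating, CUDA
//   auto out = torch::empty_like(input, input.options().dtype(torch::kInt8));
//   auto scale = torch::full({1}, 0.02f, input.options().dtype(torch::kFloat32));
//   static_scaled_int8_quant(out, input, scale, c10::nullopt);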

void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(!azp || azp->is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          aphrodite::dynamic_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), hidden_size);
        } else {
          aphrodite::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float,
                                                          int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}
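
// Host-side usage sketch for dynamic_scaled_int8_quant (illustrative only; the
// scales tensor layout shown is an assumption, beyond needing num_tokens
// contiguous float elements):
//   torch::Tensor input = ...;  // [num_tokens, hidden_size], floating, CUDA
//   auto out = torch::empty_like(input, input.options().dtype(torch::kInt8));
//   auto scales = torch::empty({input.numel() / input.size(-1), 1},
//                              input.options().dtype(torch::kFloat32));
//   dynamic_scaled_int8_quant(out, input, scales, c10::nullopt);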