1
0

quant.cpp 11 KB


  1. #include "cpu_types.hpp"
  2. #include "dnnl_helper.hpp"
  3. namespace {
  // Primary template of the scalar-type -> SIMD-vector-type mapping.
  //
  // Both aliases are `void` here; a scalar type only becomes usable once an
  // explicit specialization supplies real vector types for it.
  template <typename scalar_t>
  struct KernelVecType {
    // Vector type constructed directly from a `scalar_t` pointer.
    using load_vec_type = void;
    // Vector type the loaded values are converted to before arithmetic.
    using cvt_vec_type = void;
  };
  9. template <>
  10. struct KernelVecType<float> {
  11. using load_vec_type = vec_op::FP32Vec16;
  12. using cvt_vec_type = vec_op::FP32Vec16;
  13. };
  14. template <>
  15. struct KernelVecType<c10::BFloat16> {
  16. using load_vec_type = vec_op::BF16Vec16;
  17. using cvt_vec_type = vec_op::FP32Vec16;
  18. };
  19. #ifdef __AVX512F__
  20. template <typename scalar_t>
  21. void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
  22. const float* scale, const int num_tokens,
  23. const int hidden_size) {
  24. using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  25. using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  26. constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
  27. constexpr float i8_min =
  28. static_cast<float>(std::numeric_limits<int8_t>::min());
  29. constexpr float i8_max =
  30. static_cast<float>(std::numeric_limits<int8_t>::max());
  31. const cvt_vec_t inv_scale(1.0 / *scale);
  32. const cvt_vec_t i8_min_vec(i8_min);
  33. const cvt_vec_t i8_max_vec(i8_max);
  34. #pragma omp parallel for
  35. for (int i = 0; i < num_tokens; ++i) {
  36. int j = 0;
  37. for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
  38. load_vec_t elems(input + i * hidden_size + j);
  39. cvt_vec_t elems_fp32(elems);
  40. elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
  41. vec_op::INT8Vec16 elems_int8(elems_fp32);
  42. elems_int8.save(output + i * hidden_size + j);
  43. }
  44. load_vec_t elems(input + i * hidden_size + j);
  45. cvt_vec_t elems_fp32(elems);
  46. elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
  47. vec_op::INT8Vec16 elems_int8(elems_fp32);
  48. if (j + vec_elem_num == hidden_size) {
  49. elems_int8.save(output + i * hidden_size + j);
  50. } else {
  51. elems_int8.save(output + i * hidden_size + j, hidden_size - j);
  52. }
  53. }
  54. }
  55. template <typename scalar_t>
  56. void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
  57. float* scale, const int num_tokens,
  58. const int hidden_size) {
  59. using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  60. using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  61. constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
  62. #pragma omp parallel for
  63. for (int i = 0; i < num_tokens; ++i) {
  64. cvt_vec_t max_abs(0.0);
  65. {
  66. int j = 0;
  67. for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
  68. load_vec_t elems(input + i * hidden_size + j);
  69. cvt_vec_t elems_fp32(elems);
  70. max_abs = max_abs.max(elems_fp32.abs());
  71. }
  72. load_vec_t elems(input + i * hidden_size + j);
  73. cvt_vec_t elems_fp32(elems);
  74. if (j + vec_elem_num == hidden_size) {
  75. max_abs = max_abs.max(elems_fp32.abs());
  76. } else {
  77. max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j);
  78. }
  79. }
  80. float scale_val = max_abs.reduce_max() / 127.0f;
  81. scale[i] = scale_val;
  82. const cvt_vec_t inv_scale(1.0 / scale_val);
  83. {
  84. int j = 0;
  85. for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
  86. load_vec_t elems(input + i * hidden_size + j);
  87. cvt_vec_t elems_fp32(elems);
  88. elems_fp32 = (elems_fp32 * inv_scale);
  89. vec_op::INT8Vec16 elems_int8(elems_fp32);
  90. elems_int8.save(output + i * hidden_size + j);
  91. }
  92. load_vec_t elems(input + i * hidden_size + j);
  93. cvt_vec_t elems_fp32(elems);
  94. elems_fp32 = (elems_fp32 * inv_scale);
  95. vec_op::INT8Vec16 elems_int8(elems_fp32);
  96. if (j + vec_elem_num == hidden_size) {
  97. elems_int8.save(output + i * hidden_size + j);
  98. } else {
  99. elems_int8.save(output + i * hidden_size + j, hidden_size - j);
  100. }
  101. }
  102. }
  103. }
  104. template <bool Bias, typename scalar_t>
  105. void dynamic_output_scale_impl(const float* input, scalar_t* output,
  106. const float* scale, const scalar_t* bias,
  107. const int num_tokens, const int hidden_size) {
  108. CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
  109. using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  110. using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  111. constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
  112. #pragma omp parallel for
  113. for (int i = 0; i < num_tokens; ++i) {
  114. int j = 0;
  115. cvt_vec_t token_scale_vec(scale[i]);
  116. for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
  117. cvt_vec_t elems_fp32(input + i * hidden_size + j);
  118. elems_fp32 = elems_fp32 * token_scale_vec;
  119. if constexpr (Bias) {
  120. load_vec_t bias_vec(bias + j);
  121. cvt_vec_t bias_vec_fp32(bias_vec);
  122. elems_fp32 = elems_fp32 + bias_vec_fp32;
  123. }
  124. load_vec_t elems_out(elems_fp32);
  125. elems_out.save(output + i * hidden_size + j);
  126. }
  127. cvt_vec_t elems_fp32(input + i * hidden_size + j);
  128. elems_fp32 = elems_fp32 * token_scale_vec;
  129. if constexpr (Bias) {
  130. load_vec_t bias_vec(bias + j);
  131. cvt_vec_t bias_vec_fp32(bias_vec);
  132. elems_fp32 = elems_fp32 + bias_vec_fp32;
  133. }
  134. load_vec_t elems_out(elems_fp32);
  135. if (j + vec_elem_num == hidden_size) {
  136. elems_out.save(output + i * hidden_size + j);
  137. } else {
  138. elems_out.save(output + i * hidden_size + j, hidden_size - j);
  139. }
  140. }
  141. }
  142. #else
  143. template <typename scalar_t>
  144. void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
  145. const float* scale, const int num_tokens,
  146. const int hidden_size) {
  147. TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
  148. }
  149. template <typename scalar_t>
  150. void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
  151. float* scale, const int num_tokens,
  152. const int hidden_size) {
  153. TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
  154. }
  155. template <typename scalar_t>
  156. void dynamic_output_scale_impl() {
  157. TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.")
  158. }
  159. #endif
  160. } // namespace
  161. void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
  162. const torch::Tensor& a, // [M, IC], row-major
  163. const torch::Tensor& b, // [IC, OC], column-major
  164. const torch::Tensor& a_scales, // [1] or [M]
  165. const torch::Tensor& b_scales, // [1] or [OC]
  166. const c10::optional<torch::Tensor>& bias // [OC]
  167. ) {
  168. CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  169. // Checks for conformality
  170. TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
  171. "int8_scaled_mm only supports INT8 inputs.")
  172. TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  173. TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
  174. b.size(1) == c.size(1));
  175. TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  176. TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
  177. // Check for strides and alignment
  178. TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
  179. TORCH_CHECK(b.stride(0) == 1); // Column-major
  180. TORCH_CHECK(c.stride(0) % 16 == 0 &&
  181. b.stride(1) % 16 == 0); // 16 Byte Alignment
  182. TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  183. if (bias) {
  184. TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
  185. bias->dim() == 1);
  186. }
  187. APHRODITE_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] {
  188. if (a_scales.numel() != 1) {
  189. // per-token
  190. // Note: oneDNN doesn't support per-token activation quantization
  191. torch::Tensor tmp_fp32_out =
  192. torch::empty_like(c, ::at::ScalarType::Float);
  193. DNNLPrimitiveHelper<true>::gemm_s8s8_jit(
  194. a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
  195. tmp_fp32_out.data_ptr<float>(), (void*)(0), a.size(0), b.size(1),
  196. a.size(1), (float*)(0), b_scales.data_ptr<float>(), 0,
  197. b_scales.numel());
  198. if (bias.has_value()) {
  199. dynamic_output_scale_impl<true>(
  200. tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
  201. a_scales.data_ptr<float>(), bias->data_ptr<scalar_t>(), c.size(0),
  202. c.size(1));
  203. } else {
  204. dynamic_output_scale_impl<false>(
  205. tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
  206. a_scales.data_ptr<float>(), (scalar_t*)(0), c.size(0), c.size(1));
  207. }
  208. } else {
  209. // per-tensor
  210. if (bias.has_value()) {
  211. DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
  212. a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
  213. bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
  214. a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
  215. a_scales.numel(), b_scales.numel());
  216. } else {
  217. DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
  218. a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
  219. (void*)(0), a.size(0), b.size(1), a.size(1),
  220. a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
  221. a_scales.numel(), b_scales.numel());
  222. }
  223. }
  224. });
  225. }
  226. // static-per-tensor quantization.
  227. void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
  228. const torch::Tensor& input, // [..., hidden_size]
  229. const torch::Tensor& scale,
  230. c10::optional<torch::Tensor> const& azp) {
  231. CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
  232. TORCH_CHECK(input.is_contiguous());
  233. TORCH_CHECK(out.is_contiguous());
  234. TORCH_CHECK(scale.numel() == 1);
  235. TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");
  236. const int hidden_size = input.size(-1);
  237. const int num_tokens = input.numel() / hidden_size;
  238. APHRODITE_DISPATCH_FLOATING_TYPES(
  239. input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
  240. static_scaled_int8_quant_impl(
  241. input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
  242. scale.data_ptr<float>(), num_tokens, hidden_size);
  243. });
  244. }
  245. // dynamic-per-token quantization.
  246. void dynamic_scaled_int8_quant(
  247. torch::Tensor& out, // [..., hidden_size]
  248. const torch::Tensor& input, // [..., hidden_size]
  249. torch::Tensor& scale, // [..., 1]
  250. c10::optional<torch::Tensor> const& azp) {
  251. CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
  252. TORCH_CHECK(input.is_contiguous());
  253. TORCH_CHECK(out.is_contiguous());
  254. TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");
  255. int const hidden_size = input.size(-1);
  256. int const num_tokens = input.numel() / hidden_size;
  257. APHRODITE_DISPATCH_FLOATING_TYPES(
  258. input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
  259. dynamic_scaled_int8_quant_impl(
  260. input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
  261. scale.data_ptr<float>(), num_tokens, hidden_size);
  262. });
  263. }