// quant_ops.h
#pragma once

#include <torch/library.h>

#include "core/scalar_type.hpp"

#ifndef USE_ROCM
// AQLM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const std::vector<int64_t>& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes);

// AWQ
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);

torch::Tensor awq_group_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, torch::Tensor _topk_weights,
                             torch::Tensor _sorted_token_ids_ptr,
                             torch::Tensor _expert_ids_ptr,
                             torch::Tensor _num_tokens_post_padded,
                             bool mul_weights, int64_t split_k_iters);
#endif

// GPTQ
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

torch::Tensor group_gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                              torch::Tensor b_gptq_qzeros,
                              torch::Tensor b_gptq_scales,
                              torch::Tensor b_g_idx, torch::Tensor topk_weights,
                              torch::Tensor sorted_token_ids_ptr,
                              torch::Tensor expert_ids_ptr,
                              torch::Tensor num_tokens_post_padded,
                              bool mul_weights, bool use_exllama);

torch::Tensor dequant_gptq(torch::Tensor b_q_weight,
                           torch::Tensor b_gptq_qzeros,
                           torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                           int64_t bits, bool use_exllama);

#ifndef USE_ROCM
// Marlin
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace,
                                  aphrodite::ScalarTypeTorchPtr const& b_q_type,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
                               aphrodite::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce);

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

// GGUF
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                              int64_t n);

torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);

torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                              int64_t row);

// QuIP#
at::Tensor e8p_mm_origorder(const at::Tensor& A, const at::Tensor& B,
                            const at::Tensor& CB);

void decompress_e8p_origorder(torch::Tensor YIs, torch::Tensor CB,
                              torch::Tensor& Y);

// CUTLASS w8a8 scaled GEMM
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);
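
// Example (illustrative sketch, not part of this header): a plausible
// cutlass_scaled_mm call for an int8 W8A8 GEMM. The expected dtypes, layouts
// (row-major A, column-major B), and scale granularities are assumptions
// inferred from the parameter names; consult the CUTLASS kernel sources for
// the actual contract.
//
//   torch::Tensor a = ...;         // [M, K] int8 activations
//   torch::Tensor b = ...;         // [K, N] int8 weights
//   torch::Tensor a_scales = ...;  // fp32, per-tensor [1] or per-token [M, 1]
//   torch::Tensor b_scales = ...;  // fp32, per-tensor [1] or per-channel [1, N]
//   torch::Tensor out = torch::empty(
//       {a.size(0), b.size(1)},
//       torch::dtype(torch::kBFloat16).device(a.device()));
//   cutlass_scaled_mm(out, a, b, a_scales, b_scales, /*bias=*/c10::nullopt);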

torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& b_q_weight,
                              torch::Tensor const& s_tok,
                              torch::Tensor const& s_ch,
                              torch::Tensor const& s_group,
                              torch::Tensor& workspace, int64_t size_m,
                              int64_t size_n, int64_t size_k);

torch::Tensor fp_eXmY_linear_forward_cuda(int64_t EXPONENT, int64_t MANTISSA,
                                          torch::Tensor _in_feats,
                                          torch::Tensor _weights,
                                          torch::Tensor _scales,
                                          int64_t splitK = 1);
#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);

void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales);

// SqueezeLLM
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

// FP8
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                             torch::Tensor const& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor& scale);

void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
    c10::optional<torch::Tensor> const& scale_ub);
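
// Example (illustrative sketch, not part of this header): a plausible call to
// static_scaled_fp8_quant with a precomputed per-tensor scale. The fp8 output
// dtype and the one-element fp32 scale layout are assumptions inferred from
// the parameter names, not guaranteed by this declaration.
//
//   torch::Tensor input = ...;  // [num_tokens, hidden_size], fp16/bf16
//   torch::Tensor scale =
//       torch::full({1}, 0.5f, input.options().dtype(torch::kFloat32));
//   torch::Tensor out =
//       torch::empty_like(input, input.options().dtype(torch::kFloat8_e4m3fn));
//   static_scaled_fp8_quant(out, input, scale);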

// flute
#include <cuda_runtime.h>

#include <torch/library.h>
#include <torch/all.h>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>

#include "cute/numeric/integral_constant.hpp"

// Low-level FLUTE GEMM entry points. `_qgemm_raw` additionally takes an
// explicit `template_id` selecting a specific kernel configuration.
template <typename SMs, typename T, typename TQ, typename T2, typename NumBits,
          typename GroupSize>
void _qgemm(int M, int N, int K, int P, const T* const __restrict__ A,
            const TQ* const __restrict__ Q, T* __restrict__ D,
            const T* const __restrict__ S, const T* const __restrict__ QM,
            const T2* const __restrict__ QM2, void* __restrict__ workspace,
            const cudaStream_t stream);

template <typename SMs, typename T, typename TQ, typename T2, typename NumBits,
          typename GroupSize>
void _qgemm_raw(int M, int N, int K, int P, const T* const __restrict__ A,
                const TQ* const __restrict__ Q, T* __restrict__ D,
                const T* const __restrict__ S, const T* const __restrict__ QM,
                const T2* const __restrict__ QM2, void* __restrict__ workspace,
                const int template_id, const cudaStream_t stream);

// Typed wrapper: extracts the problem sizes and raw device pointers from the
// ATen tensors and forwards them to _qgemm.
template <typename SMs, typename T, typename NumBits, typename GroupSize>
void qgemm(const at::Tensor& input, const at::Tensor& weight,
           at::Tensor& output, const at::Tensor& scales,
           const at::Tensor& table, const at::Tensor& table2,
           at::Tensor& workspace, const cudaStream_t stream) {
  using namespace cute;
  using TQ = cute::uint16_t;
  using T2 = conditional_t<is_same_v<T, half_t>, __half2, __nv_bfloat162>;

  _qgemm<SMs, T, TQ, T2, NumBits,
         GroupSize>(output.size(0),  // M
                    output.size(1),  // N
                    input.size(1),   // K
                    weight.size(0),  // P
                    reinterpret_cast<const T*>(input.data_ptr()),
                    reinterpret_cast<const TQ*>(weight.data_ptr()),
                    reinterpret_cast<T*>(output.data_ptr()),
                    reinterpret_cast<const T*>(scales.data_ptr()),
                    reinterpret_cast<const T*>(table.data_ptr()),
                    reinterpret_cast<const T2*>(table2.data_ptr()),
                    reinterpret_cast<void*>(workspace.data_ptr()), stream);

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Same as qgemm, but additionally forwards the caller-provided `template_id`
// to _qgemm_raw.
template <typename SMs, typename T, typename NumBits, typename GroupSize>
void qgemm_raw(const at::Tensor& input, const at::Tensor& weight,
               at::Tensor& output, const at::Tensor& scales,
               const at::Tensor& table, const at::Tensor& table2,
               at::Tensor& workspace, const int template_id,
               const cudaStream_t stream) {
  using namespace cute;
  using TQ = cute::uint16_t;
  using T2 = conditional_t<is_same_v<T, half_t>, __half2, __nv_bfloat162>;

  _qgemm_raw<SMs, T, TQ, T2, NumBits,
             GroupSize>(output.size(0),  // M
                        output.size(1),  // N
                        input.size(1),   // K
                        weight.size(0),  // P
                        reinterpret_cast<const T*>(input.data_ptr()),
                        reinterpret_cast<const TQ*>(weight.data_ptr()),
                        reinterpret_cast<T*>(output.data_ptr()),
                        reinterpret_cast<const T*>(scales.data_ptr()),
                        reinterpret_cast<const T*>(table.data_ptr()),
                        reinterpret_cast<const T2*>(table2.data_ptr()),
                        reinterpret_cast<void*>(workspace.data_ptr()),
                        template_id, stream);

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename SMs>
at::Tensor qgemm_simple(const at::Tensor& input, const at::Tensor& weight,
                        const at::Tensor& scales, const at::Tensor& table,
                        const at::Tensor& table2, at::Tensor& workspace,
                        const cute::int64_t num_bits,
                        const cute::int64_t group_size) {
  // Set the device of this function, primarily used when
  // we have multiple devices in the same process.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

  // Get the current CUDA stream, primarily used
  // to make CUDA Graphs work.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Squash the batch dimensions of the input tensor with its
  // next-to-last dimensions.
  const auto input_sizes = input.sizes().vec();
  const auto input_2d = input.reshape({-1, input_sizes.back()});

  auto output = at::empty(
      {input_2d.size(0), scales.size(0)},
      at::TensorOptions().dtype(input_2d.dtype()).device(input_2d.device()));

  // Instantiate qgemm for a concrete (dtype, num_bits, group_size) combination.
#define RUN_QGEMM(T, NUM_BITS, GROUP_SIZE)                                    \
  do {                                                                        \
    qgemm<SMs, T, cute::Int<NUM_BITS>, cute::Int<GROUP_SIZE> >(               \
        input_2d, weight, output, scales, table, table2, workspace, stream);  \
  } while (false)

#define RUN_QGEMM_SWITCH_GROUP_SIZE(T, NUM_BITS) \
  do {                                           \
    switch (group_size) {                        \
      case 32:                                   \
        RUN_QGEMM(T, NUM_BITS, 32);              \
        break;                                   \
      case 64:                                   \
        RUN_QGEMM(T, NUM_BITS, 64);              \
        break;                                   \
      case 128:                                  \
        RUN_QGEMM(T, NUM_BITS, 128);             \
        break;                                   \
      case 256:                                  \
        RUN_QGEMM(T, NUM_BITS, 256);             \
        break;                                   \
      default:                                   \
        AT_ERROR("Unsupported `group_size`");    \
    }                                            \
  } while (false)

#define RUN_QGEMM_SWITCH_NUM_BITS_AND_GROUP_SIZE(T) \
  do {                                              \
    switch (num_bits) {                             \
      case 2:                                       \
        RUN_QGEMM_SWITCH_GROUP_SIZE(T, 2);          \
        break;                                      \
      case 3:                                       \
        RUN_QGEMM_SWITCH_GROUP_SIZE(T, 3);          \
        break;                                      \
      case 4:                                       \
        RUN_QGEMM_SWITCH_GROUP_SIZE(T, 4);          \
        break;                                      \
      default:                                      \
        AT_ERROR("Unsupported `num_bits`");         \
    }                                               \
  } while (false)

  // Dispatch on the activation dtype (fp16 / bf16), then on the runtime
  // num_bits (2/3/4) and group_size (32/64/128/256) values.
  AT_DISPATCH_SWITCH(
      input.scalar_type(), "qgemm_simple",
      AT_DISPATCH_CASE(at::ScalarType::Half, [&]() {
        RUN_QGEMM_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::half_t);
        return;
      }) AT_DISPATCH_CASE(at::ScalarType::BFloat16, [&]() {
        RUN_QGEMM_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::bfloat16_t);
        return;
      }));

  // Restore the original batch dimensions, with the last dimension replaced
  // by the output feature dimension.
  auto output_sizes = input_sizes;
  output_sizes.back() = scales.size(0);
  return output.reshape(output_sizes);
}
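
// Example (illustrative sketch, not part of this header): a plausible
// qgemm_simple call for a 4-bit, group-size-128 FLUTE weight. The SMs tag
// (cute::Int<108> here) and the packed-tensor preparation are assumptions.
//
//   at::Tensor out = qgemm_simple<cute::Int<108>>(
//       input, packed_weight, scales, table, table2, workspace,
//       /*num_bits=*/4, /*group_size=*/128);
//   // `out` keeps input's batch dimensions, with the last dimension replaced
//   // by scales.size(0).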

// Variant of qgemm_simple that writes into a caller-allocated `output` and
// selects a specific kernel configuration via `template_id`.
template <typename SMs>
void qgemm_raw_simple(const at::Tensor& input, const at::Tensor& weight,
                      at::Tensor& output, const at::Tensor& scales,
                      const at::Tensor& table, const at::Tensor& table2,
                      at::Tensor& workspace, const cute::int64_t num_bits,
                      const cute::int64_t group_size,
                      const cute::int64_t template_id) {
  // Set the device of this function, primarily used when
  // we have multiple devices in the same process.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

  // Get the current CUDA stream, primarily used
  // to make CUDA Graphs work.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Instantiate qgemm_raw for a concrete (dtype, num_bits, group_size)
  // combination, forwarding the requested template_id.
#define RUN_QGEMM_RAW(T, NUM_BITS, GROUP_SIZE)                                \
  do {                                                                        \
    qgemm_raw<SMs, T, cute::Int<NUM_BITS>, cute::Int<GROUP_SIZE> >(           \
        input, weight, output, scales, table, table2, workspace,              \
        template_id, stream);                                                 \
  } while (false)

#define RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, NUM_BITS) \
  do {                                               \
    switch (group_size) {                            \
      case 32:                                       \
        RUN_QGEMM_RAW(T, NUM_BITS, 32);              \
        break;                                       \
      case 64:                                       \
        RUN_QGEMM_RAW(T, NUM_BITS, 64);              \
        break;                                       \
      case 128:                                      \
        RUN_QGEMM_RAW(T, NUM_BITS, 128);             \
        break;                                       \
      case 256:                                      \
        RUN_QGEMM_RAW(T, NUM_BITS, 256);             \
        break;                                       \
      default:                                       \
        AT_ERROR("Unsupported `group_size`");        \
    }                                                \
  } while (false)

#define RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(T) \
  do {                                                  \
    switch (num_bits) {                                 \
      case 2:                                           \
        RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 2);          \
        break;                                          \
      case 3:                                           \
        RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 3);          \
        break;                                          \
      case 4:                                           \
        RUN_QGEMM_RAW_SWITCH_GROUP_SIZE(T, 4);          \
        break;                                          \
      default:                                          \
        AT_ERROR("Unsupported `num_bits`");             \
    }                                                   \
  } while (false)

  // Dispatch on the activation dtype (fp16 / bf16), then on the runtime
  // num_bits and group_size values.
  AT_DISPATCH_SWITCH(
      input.scalar_type(), "qgemm_raw_simple",
      AT_DISPATCH_CASE(at::ScalarType::Half, [&]() {
        RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::half_t);
        return;
      }) AT_DISPATCH_CASE(at::ScalarType::BFloat16, [&]() {
        RUN_QGEMM_RAW_SWITCH_NUM_BITS_AND_GROUP_SIZE(cute::bfloat16_t);
        return;
      }));
}
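
// Example (illustrative sketch, not part of this header): a plausible
// qgemm_raw_simple call. The 2-D input layout, the output shape, the SMs tag,
// and the valid `template_id` range are assumptions, not prescribed here.
//
//   auto output = at::empty({input.size(0), scales.size(0)}, input.options());
//   qgemm_raw_simple<cute::Int<108>>(input, packed_weight, output, scales,
//                                    table, table2, workspace,
//                                    /*num_bits=*/4, /*group_size=*/128,
//                                    /*template_id=*/0);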