// torch_bindings.cpp

#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include "quantization/quant_ops.h"
#include "cute/numeric/integral_constant.hpp"

#include <torch/library.h>
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed documentation on op registration and
// function schemas:
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
  17. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  18. // Aphrodite custom ops
  19. // Attention ops
  20. // Compute the attention between an input query and the cached
  21. // keys/values using PagedAttention.
  22. ops.def(
  23. "paged_attention_v1("
  24. " Tensor! out, Tensor query, Tensor key_cache,"
  25. " Tensor value_cache, int num_kv_heads, float scale,"
  26. " Tensor block_tables, Tensor seq_lens, int block_size,"
  27. " int max_seq_len, Tensor? alibi_slopes,"
  28. " str kv_cache_dtype, float k_scale, float v_scale,"
  29. " int tp_rank, int blocksparse_local_blocks,"
  30. " int blocksparse_vert_stride, int blocksparse_block_size,"
  31. " int blocksparse_head_sliding_step) -> ()");
  32. ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
  33. // PagedAttention V2.
  34. ops.def(
  35. "paged_attention_v2("
  36. " Tensor! out, Tensor exp_sums, Tensor max_logits,"
  37. " Tensor tmp_out, Tensor query, Tensor key_cache,"
  38. " Tensor value_cache, int num_kv_heads, float scale,"
  39. " Tensor block_tables, Tensor seq_lens, int block_size,"
  40. " int max_seq_len, Tensor? alibi_slopes,"
  41. " str kv_cache_dtype, float k_scale, float v_scale,"
  42. " int tp_rank, int blocksparse_local_blocks,"
  43. " int blocksparse_vert_stride, int blocksparse_block_size,"
  44. " int blocksparse_head_sliding_step) -> ()");
  45. ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
  46. // Activation ops
  47. // Activation function used in SwiGLU.
  48. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  49. ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
  50. // Activation function used in GeGLU with `none` approximation.
  51. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  52. ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
  53. // Activation function used in GeGLU with `tanh` approximation.
  54. ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  55. ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
  56. // GELU implementation used in GPT-2.
  57. ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  58. ops.impl("gelu_new", torch::kCUDA, &gelu_new);
  59. // Approximate GELU implementation.
  60. ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  61. ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
  62. // Quick GELU implementation.
  63. ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  64. ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
  65. // prepare_inputs advance_step
  66. ops.def("advance_step", &advance_step);
  67. ops.impl("advance_step", torch::kCUDA, &advance_step);
  68. // Layernorm
  69. // Apply Root Mean Square (RMS) Normalization to the input tensor.
  70. ops.def(
  71. "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
  72. "()");
  73. ops.impl("rms_norm", torch::kCUDA, &rms_norm);
  74. // In-place fused Add and RMS Normalization.
  75. ops.def(
  76. "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
  77. "float epsilon) -> ()");
  78. ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
  79. // Rotary embedding
  80. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  81. ops.def(
  82. "rotary_embedding(Tensor positions, Tensor! query,"
  83. " Tensor! key, int head_size,"
  84. " Tensor cos_sin_cache, bool is_neox) -> ()");
  85. ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
  86. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  87. // (supports multiple loras).
  88. ops.def(
  89. "batched_rotary_embedding(Tensor positions, Tensor! query,"
  90. " Tensor! key, int head_size,"
  91. " Tensor cos_sin_cache, bool is_neox,"
  92. " int rot_dim,"
  93. " Tensor cos_sin_cache_offsets) -> ()");
  94. ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
  95. // Quantization ops
  96. #ifndef USE_ROCM
  97. // Quantized GEMM for AQLM.
  98. ops.def("aqlm_gemm", &aqlm_gemm);
  99. ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
  100. // Decompression method for AQLM.
  101. ops.def("aqlm_dequant", &aqlm_dequant);
  102. ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
  103. // Quantized GEMM for AWQ.
  104. ops.def("awq_gemm", &awq_gemm);
  105. ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
  106. // Dequantization for AWQ.
  107. ops.def("awq_dequantize", &awq_dequantize);
  108. ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
  109. // Dequantization for GGML.
  110. ops.def("ggml_dequantize", &ggml_dequantize);
  111. ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
  112. // mmvq kernel for GGML.
  113. ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
  114. ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
  115. // mmq kernel for GGML.
  116. ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
  117. ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
  118. // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  119. ops.def("marlin_gemm", &marlin_gemm);
  120. ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
  121. // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  122. ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
  123. ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
  124. // gptq_marlin Optimized Quantized GEMM for GPTQ.
  125. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
  126. ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
  127. // gptq_marlin repack from GPTQ.
  128. ops.def("gptq_marlin_repack", &gptq_marlin_repack);
  129. ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
  130. // awq_marlin repack from AWQ.
  131. ops.def("awq_marlin_repack", &awq_marlin_repack);
  132. ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
  133. // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  134. ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
  135. ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
  136. // marlin_qqq_gemm for QQQ.
  137. ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
  138. ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
  139. // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  140. // quantization.
  141. ops.def(
  142. "cutlass_scaled_mm(Tensor! out, Tensor a,"
  143. " Tensor b, Tensor a_scales,"
  144. " Tensor b_scales, Tensor? bias) -> ()");
  145. ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
  146. // Check if cutlass scaled_mm is supported for CUDA devices of the given
  147. // capability
  148. ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
  149. ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
  150. &cutlass_scaled_mm_supports_fp8);
  151. // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  152. // quantization.
  153. ops.def(
  154. "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
  155. " Tensor b, Tensor a_scales,"
  156. " Tensor b_scales, Tensor azp_adj,"
  157. " Tensor? azp, Tensor? bias) -> ()");
  158. ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
  159. // QuIP# GEMV
  160. ops.def("quip_gemv", &e8p_mm_origorder);
  161. ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
  162. // QuIP# Decompress
  163. ops.def("quip_decompress", &decompress_e8p_origorder);
  164. ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
  165. // fp6_llm
  166. ops.def(
  167. "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA,"
  168. " Tensor _in_feats, Tensor _weights,"
  169. " Tensor _scales, int splitK=1) -> Tensor");
  170. ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
  171. &fp_eXmY_linear_forward_cuda);
  172. ops.def(
  173. "qgemm_simple_80(Tensor input, Tensor weight, Tensor scales, Tensor "
  174. "table, Tensor table2, Tensor(a!) workspace, int num_bits, int "
  175. "group_size) -> Tensor");
  176. ops.def(
  177. "qgemm_simple_86(Tensor input, Tensor weight, Tensor scales, Tensor "
  178. "table, Tensor table2, Tensor(a!) workspace, int num_bits, int "
  179. "group_size) -> Tensor");
  180. ops.def(
  181. "qgemm_simple_89(Tensor input, Tensor weight, Tensor scales, Tensor "
  182. "table, Tensor table2, Tensor(a!) workspace, int num_bits, int "
  183. "group_size) -> Tensor");
  184. ops.def(
  185. "qgemm_raw_simple_80(Tensor input, Tensor weight, Tensor(a!) output, "
  186. "Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int "
  187. "num_bits, int group_size, int template_id) -> ()");
  188. ops.def(
  189. "qgemm_raw_simple_86(Tensor input, Tensor weight, Tensor(a!) output, "
  190. "Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int "
  191. "num_bits, int group_size, int template_id) -> ()");
  192. ops.def(
  193. "qgemm_raw_simple_89(Tensor input, Tensor weight, Tensor(a!) output, "
  194. "Tensor scales, Tensor table, Tensor table2, Tensor(b!) workspace, int "
  195. "num_bits, int group_size, int template_id) -> ()");
  196. ops.impl("qgemm_simple_80", &qgemm_simple<cute::Int<108>>);
  197. ops.impl("qgemm_simple_86", &qgemm_simple<cute::Int<84>>);
  198. ops.impl("qgemm_simple_89", &qgemm_simple<cute::Int<128>>);
  199. ops.impl("qgemm_raw_simple_80", &qgemm_raw_simple<cute::Int<108>>);
  200. ops.impl("qgemm_raw_simple_86", &qgemm_raw_simple<cute::Int<84>>);
  201. ops.impl("qgemm_raw_simple_89", &qgemm_raw_simple<cute::Int<128>>);
  202. #endif
  203. // Quantized GEMM for GPTQ.
  204. ops.def("gptq_gemm", &gptq_gemm);
  205. ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
  206. // Post processing for GPTQ.
  207. ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  208. ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
  209. // Quantized GEMM for SqueezeLLM.
  210. ops.def(
  211. "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
  212. "lookup_table) -> ()");
  213. ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
  214. // Compute FP8 quantized tensor for given scaling factor.
  215. ops.def(
  216. "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
  217. ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
  218. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
  219. ops.def(
  220. "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  221. "()");
  222. ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
  223. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
  224. ops.def(
  225. "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
  226. "scale, Tensor? scale_ub) -> "
  227. "()");
  228. ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
  229. &dynamic_per_token_scaled_fp8_quant);
  230. // Aligning the number of tokens to be processed by each expert such
  231. // that it is divisible by the block size.
  232. ops.def(
  233. "moe_align_block_size(Tensor topk_ids, int num_experts,"
  234. " int block_size, Tensor! sorted_token_ids,"
  235. " Tensor! experts_ids,"
  236. " Tensor! num_tokens_post_pad) -> ()");
  237. ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  238. // Compute int8 quantized tensor for given scaling factor.
  239. /*
  240. Implementation:
  241. void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  242. input, torch::Tensor const& scale);
  243. */
  244. ops.def(
  245. "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
  246. "()");
  247. ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
  248. // Compute int8 quantized tensor and scaling factor
  249. /*
  250. Implementation:
  251. void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  252. input, torch::Tensor& scales);
  253. */
  254. ops.def(
  255. "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  256. "()");
  257. ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
  258. &dynamic_scaled_int8_quant);
  259. // Mamba kernels
  260. ops.def(
  261. "selective_scan_fwd(Tensor! u, Tensor! delta,"
  262. "Tensor! A, Tensor! B, Tensor! C,"
  263. "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
  264. "bool delta_softplus,"
  265. "Tensor? index_, Tensor? x) -> Tensor[]");
  266. ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
  267. ops.def(
  268. "causal_conv1d_update(Tensor! x,"
  269. "Tensor! conv_state,"
  270. "Tensor! weight,"
  271. "Tensor? bias_,"
  272. "bool silu_activation) -> Tensor");
  273. ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
  274. ops.def(
  275. "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
  276. "Tensor? bias_,"
  277. "Tensor? seq_idx_,"
  278. "Tensor? seq_pos_idx_,"
  279. "Tensor? initial_states_,"
  280. "Tensor? final_states_out_,"
  281. "bool silu_activation) -> Tensor");
  282. ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
  283. }
  284. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  285. // Cache ops
  286. // Swap in (out) the cache blocks from src to dst.
  287. cache_ops.def(
  288. "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  289. cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
  290. // Copy the cache blocks from src to dst.
  291. cache_ops.def(
  292. "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
  293. "block_mapping) -> ()");
  294. cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
  295. // Reshape the key and value tensors and cache them.
  296. cache_ops.def(
  297. "reshape_and_cache(Tensor key, Tensor value,"
  298. " Tensor! key_cache, Tensor! value_cache,"
  299. " Tensor slot_mapping,"
  300. " str kv_cache_dtype,"
  301. " float k_scale, float v_scale) -> ()");
  302. cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
  303. // Reshape the key and value tensors and cache them.
  304. cache_ops.def(
  305. "reshape_and_cache_flash(Tensor key, Tensor value,"
  306. " Tensor! key_cache,"
  307. " Tensor! value_cache,"
  308. " Tensor slot_mapping,"
  309. " str kv_cache_dtype,"
  310. " float k_scale, float v_scale) -> ()");
  311. cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
  312. &reshape_and_cache_flash);
  313. // Convert the key and value cache to fp8 data type.
  314. cache_ops.def(
  315. "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
  316. "kv_cache_dtype) -> ()");
  317. cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
  318. }
  319. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  320. // Cuda utils
  321. // Gets the specified device attribute.
  322. cuda_utils.def("get_device_attribute", &get_device_attribute);
  323. cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
  324. // Gets the maximum shared memory per block device attribute.
  325. cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
  326. &get_max_shared_memory_per_block_device_attribute);
  327. cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
  328. torch::kCUDA,
  329. &get_max_shared_memory_per_block_device_attribute);
  330. }
  331. #ifndef USE_ROCM
  332. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  333. // Custom all-reduce kernels
  334. custom_ar.def("init_custom_ar", &init_custom_ar);
  335. custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  336. custom_ar.def("should_custom_ar", &should_custom_ar);
  337. custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
  338. custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  339. custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
  340. custom_ar.def(
  341. "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
  342. "()");
  343. custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
  344. custom_ar.def("dispose", &dispose);
  345. custom_ar.impl("dispose", torch::kCPU, &dispose);
  346. custom_ar.def("meta_size", &meta_size);
  347. custom_ar.impl("meta_size", torch::kCPU, &meta_size);
  348. custom_ar.def("register_buffer", &register_buffer);
  349. custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
  350. custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  351. custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
  352. &get_graph_buffer_ipc_meta);
  353. custom_ar.def("register_graph_buffers", &register_graph_buffers);
  354. custom_ar.impl("register_graph_buffers", torch::kCPU,
  355. &register_graph_buffers);
  356. }
  357. #endif
  358. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)