torch_bindings.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. #include "cache.h"
  2. #include "cuda_utils.h"
  3. #include "ops.h"
  4. #include "core/registration.h"
  5. #include "quantization/quant_ops.h"
  6. #include <torch/library.h>
  7. // Note on op signatures:
  8. // The X_meta signatures are for the meta functions corresponding to op X.
  9. // They must be kept in sync with the signature for X. Generally, only
  10. // functions that return Tensors require a meta function.
  11. //
  12. // See the following links for detailed docs on op registration and function
  13. // schemas.
  14. // https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
  15. // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
  16. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  17. // Aphrodite custom ops
  18. // Attention ops
  19. // Compute the attention between an input query and the cached
  20. // keys/values using PagedAttention.
  21. ops.def(
  22. "paged_attention_v1("
  23. " Tensor! out, Tensor query, Tensor key_cache,"
  24. " Tensor value_cache, int num_kv_heads, float scale,"
  25. " Tensor block_tables, Tensor seq_lens, int block_size,"
  26. " int max_seq_len, Tensor? alibi_slopes,"
  27. " str kv_cache_dtype, float k_scale, float v_scale,"
  28. " int tp_rank, int blocksparse_local_blocks,"
  29. " int blocksparse_vert_stride, int blocksparse_block_size,"
  30. " int blocksparse_head_sliding_step) -> ()");
  31. ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
  32. // PagedAttention V2.
  33. ops.def(
  34. "paged_attention_v2("
  35. " Tensor! out, Tensor exp_sums, Tensor max_logits,"
  36. " Tensor tmp_out, Tensor query, Tensor key_cache,"
  37. " Tensor value_cache, int num_kv_heads, float scale,"
  38. " Tensor block_tables, Tensor seq_lens, int block_size,"
  39. " int max_seq_len, Tensor? alibi_slopes,"
  40. " str kv_cache_dtype, float k_scale, float v_scale,"
  41. " int tp_rank, int blocksparse_local_blocks,"
  42. " int blocksparse_vert_stride, int blocksparse_block_size,"
  43. " int blocksparse_head_sliding_step) -> ()");
  44. ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
  45. // Activation ops
  46. // Activation function used in SwiGLU.
  47. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  48. ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
  49. // Activation function used in GeGLU with `none` approximation.
  50. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  51. ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
  52. // Activation function used in GeGLU with `tanh` approximation.
  53. ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  54. ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
  55. // GELU implementation used in GPT-2.
  56. ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  57. ops.impl("gelu_new", torch::kCUDA, &gelu_new);
  58. // Approximate GELU implementation.
  59. ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  60. ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
  61. // Quick GELU implementation.
  62. ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  63. ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
  64. // prepare_inputs advance_step
  65. ops.def("advance_step", &advance_step);
  66. ops.impl("advance_step", torch::kCUDA, &advance_step);
  67. // Layernorm
  68. // Apply Root Mean Square (RMS) Normalization to the input tensor.
  69. ops.def(
  70. "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
  71. "()");
  72. ops.impl("rms_norm", torch::kCUDA, &rms_norm);
  73. // In-place fused Add and RMS Normalization.
  74. ops.def(
  75. "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
  76. "float epsilon) -> ()");
  77. ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
  78. // Rotary embedding
  79. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  80. ops.def(
  81. "rotary_embedding(Tensor positions, Tensor! query,"
  82. " Tensor! key, int head_size,"
  83. " Tensor cos_sin_cache, bool is_neox) -> ()");
  84. ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
  85. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  86. // (supports multiple loras).
  87. ops.def(
  88. "batched_rotary_embedding(Tensor positions, Tensor! query,"
  89. " Tensor! key, int head_size,"
  90. " Tensor cos_sin_cache, bool is_neox,"
  91. " int rot_dim,"
  92. " Tensor cos_sin_cache_offsets) -> ()");
  93. ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
  94. // Quantization ops
  95. #ifndef USE_ROCM
  96. // Quantized GEMM for AQLM.
  97. ops.def("aqlm_gemm", &aqlm_gemm);
  98. ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
  99. // Decompression method for AQLM.
  100. ops.def("aqlm_dequant", &aqlm_dequant);
  101. ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
  102. // Quantized GEMM for AWQ.
  103. ops.def("awq_gemm", &awq_gemm);
  104. ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
  105. // Dequantization for AWQ.
  106. ops.def("awq_dequantize", &awq_dequantize);
  107. ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
  108. // Dequantization for GGML.
  109. ops.def("ggml_dequantize", &ggml_dequantize);
  110. ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
  111. // mmvq kernel for GGML.
  112. ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
  113. ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
  114. // mmq kernel for GGML.
  115. ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
  116. ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
  117. // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  118. ops.def("marlin_gemm", &marlin_gemm);
  119. ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
  120. // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  121. ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
  122. ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
  123. // gptq_marlin Optimized Quantized GEMM for GPTQ.
  124. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
  125. ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
  126. // gptq_marlin repack from GPTQ.
  127. ops.def("gptq_marlin_repack", &gptq_marlin_repack);
  128. ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
  129. // awq_marlin repack from AWQ.
  130. ops.def("awq_marlin_repack", &awq_marlin_repack);
  131. ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
  132. // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  133. ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
  134. ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
  135. #ifndef _WIN32
  136. // marlin_qqq_gemm for QQQ.
  137. ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
  138. ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
  139. // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  140. // quantization.
  141. ops.def(
  142. "cutlass_scaled_mm(Tensor! out, Tensor a,"
  143. " Tensor b, Tensor a_scales,"
  144. " Tensor b_scales, Tensor? bias) -> ()");
  145. ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
  146. // Check if cutlass scaled_mm is supported for CUDA devices of the given
  147. // capability
  148. ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
  149. ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
  150. &cutlass_scaled_mm_supports_fp8);
  151. // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  152. // quantization.
  153. ops.def(
  154. "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
  155. " Tensor b, Tensor a_scales,"
  156. " Tensor b_scales, Tensor azp_adj,"
  157. " Tensor? azp, Tensor? bias) -> ()");
  158. ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
  159. // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
  160. ops.def("machete_supported_schedules", &machete::supported_schedules);
  161. ops.def(
  162. "machete_gemm(Tensor A, Tensor B,"
  163. " __torch__.torch.classes._core_C.ScalarType btype,"
  164. " Tensor? scales, Tensor? zeros, int? group_size,"
  165. " Tensor? C, float? alpha, float? beta, str? schedule)"
  166. "-> Tensor");
  167. ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
  168. ops.def(
  169. "machete_prepack_B(Tensor B,"
  170. " __torch__.torch.classes._core_C.ScalarType btype)"
  171. "-> Tensor");
  172. ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
  173. ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
  174. ops.impl("permute_cols", torch::kCUDA, &permute_cols);
  175. #endif
  176. // QuIP# GEMV
  177. ops.def("quip_gemv", &e8p_mm_origorder);
  178. ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
  179. // QuIP# Decompress
  180. ops.def("quip_decompress", &decompress_e8p_origorder);
  181. ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
  182. // fp6_llm
  183. ops.def(
  184. "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA,"
  185. " Tensor _in_feats, Tensor _weights,"
  186. " Tensor _scales, int splitK=1) -> Tensor");
  187. ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
  188. &fp_eXmY_linear_forward_cuda);
  189. // Sampling Kernels
  190. ops.def("sampling_from_probs", &sampling_from_probs);
  191. ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs);
  192. ops.def("top_k_sampling_from_probs", &top_k_sampling_from_probs);
  193. ops.impl("top_k_sampling_from_probs", torch::kCUDA,
  194. &top_k_sampling_from_probs);
  195. ops.def("min_p_sampling_from_probs", &min_p_sampling_from_probs);
  196. ops.impl("min_p_sampling_from_probs", torch::kCUDA,
  197. &min_p_sampling_from_probs);
  198. ops.def("top_p_sampling_from_probs", &top_p_sampling_from_probs);
  199. ops.impl("top_p_sampling_from_probs", torch::kCUDA,
  200. &top_p_sampling_from_probs);
  201. ops.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs);
  202. ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA,
  203. &top_k_top_p_sampling_from_probs);
  204. ops.def("top_k_renorm_prob", &top_k_renorm_prob);
  205. ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
  206. ops.def("top_p_renorm_prob", &top_p_renorm_prob);
  207. ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
  208. ops.def("top_k_mask_logits", &top_k_mask_logits);
  209. ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
  210. #endif
  211. // Quantized GEMM for GPTQ.
  212. ops.def("gptq_gemm", &gptq_gemm);
  213. ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
  214. // Post processing for GPTQ.
  215. ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  216. ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
  217. // Quantized GEMM for SqueezeLLM.
  218. ops.def(
  219. "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
  220. "lookup_table) -> ()");
  221. ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
  222. // Compute FP8 quantized tensor for given scaling factor.
  223. ops.def(
  224. "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
  225. ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
  226. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
  227. ops.def(
  228. "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  229. "()");
  230. ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
  231. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
  232. ops.def(
  233. "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
  234. "scale, Tensor? scale_ub) -> "
  235. "()");
  236. ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
  237. &dynamic_per_token_scaled_fp8_quant);
  238. // Aligning the number of tokens to be processed by each expert such
  239. // that it is divisible by the block size.
  240. ops.def(
  241. "moe_align_block_size(Tensor topk_ids, int num_experts,"
  242. " int block_size, Tensor! sorted_token_ids,"
  243. " Tensor! experts_ids,"
  244. " Tensor! num_tokens_post_pad) -> ()");
  245. ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  246. // Compute int8 quantized tensor for given scaling factor.
  247. /*
  248. Implementation:
  249. void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  250. input, torch::Tensor const& scale);
  251. */
  252. ops.def(
  253. "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
  254. "()");
  255. ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
  256. // Compute int8 quantized tensor and scaling factor
  257. /*
  258. Implementation:
  259. void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  260. input, torch::Tensor& scales);
  261. */
  262. ops.def(
  263. "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  264. "()");
  265. ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
  266. &dynamic_scaled_int8_quant);
  267. #ifndef USE_ROCM
  268. // Mamba kernels
  269. ops.def(
  270. "selective_scan_fwd(Tensor! u, Tensor! delta,"
  271. "Tensor! A, Tensor! B, Tensor! C,"
  272. "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
  273. "bool delta_softplus,"
  274. "Tensor? index_, Tensor? x) -> Tensor[]");
  275. ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
  276. ops.def(
  277. "causal_conv1d_update(Tensor! x,"
  278. "Tensor! conv_state,"
  279. "Tensor! weight,"
  280. "Tensor? bias_,"
  281. "bool silu_activation) -> Tensor");
  282. ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
  283. ops.def(
  284. "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
  285. "Tensor? bias_,"
  286. "Tensor? seq_idx_,"
  287. "Tensor? seq_pos_idx_,"
  288. "Tensor? initial_states_,"
  289. "Tensor? final_states_out_,"
  290. "bool silu_activation) -> Tensor");
  291. ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
  292. #endif
  293. }
  294. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  295. // Cache ops
  296. // Swap in (out) the cache blocks from src to dst.
  297. cache_ops.def(
  298. "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  299. cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
  300. // Copy the cache blocks from src to dst.
  301. cache_ops.def(
  302. "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
  303. "block_mapping) -> ()");
  304. cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
  305. // Reshape the key and value tensors and cache them.
  306. cache_ops.def(
  307. "reshape_and_cache(Tensor key, Tensor value,"
  308. " Tensor! key_cache, Tensor! value_cache,"
  309. " Tensor slot_mapping,"
  310. " str kv_cache_dtype,"
  311. " float k_scale, float v_scale) -> ()");
  312. cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
  313. // Reshape the key and value tensors and cache them.
  314. cache_ops.def(
  315. "reshape_and_cache_flash(Tensor key, Tensor value,"
  316. " Tensor! key_cache,"
  317. " Tensor! value_cache,"
  318. " Tensor slot_mapping,"
  319. " str kv_cache_dtype,"
  320. " float k_scale, float v_scale) -> ()");
  321. cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
  322. &reshape_and_cache_flash);
  323. // Convert the key and value cache to fp8 data type.
  324. cache_ops.def(
  325. "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
  326. "kv_cache_dtype) -> ()");
  327. cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
  328. }
  329. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  330. // Cuda utils
  331. // Gets the specified device attribute.
  332. cuda_utils.def("get_device_attribute", &get_device_attribute);
  333. cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
  334. // Gets the maximum shared memory per block device attribute.
  335. cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
  336. &get_max_shared_memory_per_block_device_attribute);
  337. cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
  338. torch::kCUDA,
  339. &get_max_shared_memory_per_block_device_attribute);
  340. }
  341. #ifndef USE_ROCM
  342. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  343. // Custom all-reduce kernels
  344. custom_ar.def("init_custom_ar", &init_custom_ar);
  345. custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  346. custom_ar.def("should_custom_ar", &should_custom_ar);
  347. custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
  348. custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  349. custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
  350. custom_ar.def(
  351. "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
  352. "()");
  353. custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
  354. custom_ar.def("dispose", &dispose);
  355. custom_ar.impl("dispose", torch::kCPU, &dispose);
  356. custom_ar.def("meta_size", &meta_size);
  357. custom_ar.impl("meta_size", torch::kCPU, &meta_size);
  358. custom_ar.def("register_buffer", &register_buffer);
  359. custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
  360. custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  361. custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
  362. &get_graph_buffer_ipc_meta);
  363. custom_ar.def("register_graph_buffers", &register_graph_buffers);
  364. custom_ar.impl("register_graph_buffers", torch::kCPU,
  365. &register_graph_buffers);
  366. }
  367. #endif
  368. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)