// torch_bindings.cpp
  1. #include "cache.h"
  2. #include "cuda_utils.h"
  3. #include "ops.h"
  4. #include "core/registration.h"
  5. #include "quantization/quant_ops.h"
  6. #include "flash_attn/flash_api.h"
  7. #include <torch/library.h>
  8. // Note on op signatures:
  9. // The X_meta signatures are for the meta functions corresponding to op X.
  10. // They must be kept in sync with the signature for X. Generally, only
  11. // functions that return Tensors require a meta function.
  12. //
  13. // See the following links for detailed docs on op registration and function
  14. // schemas.
  15. // https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
  16. // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
  17. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  18. // Aphrodite custom ops
  19. // Attention ops
  20. // Compute the attention between an input query and the cached
  21. // keys/values using PagedAttention.
  22. ops.def(
  23. "paged_attention_v1("
  24. " Tensor! out, Tensor query, Tensor key_cache,"
  25. " Tensor value_cache, int num_kv_heads, float scale,"
  26. " Tensor block_tables, Tensor seq_lens, int block_size,"
  27. " int max_seq_len, Tensor? alibi_slopes,"
  28. " str kv_cache_dtype, float k_scale, float v_scale,"
  29. " int tp_rank, int blocksparse_local_blocks,"
  30. " int blocksparse_vert_stride, int blocksparse_block_size,"
  31. " int blocksparse_head_sliding_step) -> ()");
  32. ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
  33. // PagedAttention V2.
  34. ops.def(
  35. "paged_attention_v2("
  36. " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
  37. " Tensor! tmp_out, Tensor query, Tensor key_cache,"
  38. " Tensor value_cache, int num_kv_heads, float scale,"
  39. " Tensor block_tables, Tensor seq_lens, int block_size,"
  40. " int max_seq_len, Tensor? alibi_slopes,"
  41. " str kv_cache_dtype, float k_scale, float v_scale,"
  42. " int tp_rank, int blocksparse_local_blocks,"
  43. " int blocksparse_vert_stride, int blocksparse_block_size,"
  44. " int blocksparse_head_sliding_step) -> ()");
  45. ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
  46. // Activation ops
  47. // Activation function used in SwiGLU.
  48. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  49. ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
  50. // Activation function used in GeGLU with `none` approximation.
  51. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  52. ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
  53. // Activation function used in GeGLU with `tanh` approximation.
  54. ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  55. ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
  56. // GELU implementation used in GPT-2.
  57. ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  58. ops.impl("gelu_new", torch::kCUDA, &gelu_new);
  59. // Approximate GELU implementation.
  60. ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  61. ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
  62. // Quick GELU implementation.
  63. ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  64. ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
  65. // prepare_inputs advance_step
  66. ops.def(
  67. "advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
  68. "Tensor! input_tokens, Tensor sampled_token_ids, "
  69. "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
  70. "Tensor block_tables) -> ()");
  71. ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
  72. ops.def(
  73. "advance_step_flashinfer("
  74. " int num_seqs, int num_queries, int block_size,"
  75. " Tensor! input_tokens, Tensor sampled_token_ids,"
  76. " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
  77. " Tensor block_tables, Tensor! paged_kv_indices,"
  78. " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
  79. " Tensor! block_table_bounds"
  80. ") -> ()");
  81. ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
  82. // Layernorm
  83. // Apply Root Mean Square (RMS) Normalization to the input tensor.
  84. ops.def(
  85. "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
  86. "()");
  87. ops.impl("rms_norm", torch::kCUDA, &rms_norm);
  88. // In-place fused Add and RMS Normalization.
  89. ops.def(
  90. "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
  91. "float epsilon) -> ()");
  92. ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
  93. // Rotary embedding
  94. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  95. ops.def(
  96. "rotary_embedding(Tensor positions, Tensor! query,"
  97. " Tensor! key, int head_size,"
  98. " Tensor cos_sin_cache, bool is_neox) -> ()");
  99. ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
  100. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  101. // (supports multiple loras).
  102. ops.def(
  103. "batched_rotary_embedding(Tensor positions, Tensor! query,"
  104. " Tensor! key, int head_size,"
  105. " Tensor cos_sin_cache, bool is_neox,"
  106. " int rot_dim,"
  107. " Tensor cos_sin_cache_offsets) -> ()");
  108. ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
  109. // Quantization ops
  110. #ifndef USE_ROCM
  111. // Quantized GEMM for AQLM.
  112. ops.def(
  113. "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
  114. "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
  115. "-> Tensor");
  116. ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
  117. // Decompression method for AQLM.
  118. ops.def(
  119. "aqlm_dequant(Tensor codes, Tensor codebooks, "
  120. "int[] codebook_partition_sizes) -> Tensor");
  121. ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
  122. // Quantized GEMM for AWQ.
  123. ops.def(
  124. "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
  125. "Tensor _zeros, int split_k_iters) -> Tensor");
  126. ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
  127. // Dequantization for AWQ.
  128. ops.def(
  129. "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
  130. "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
  131. ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
  132. // Dequantization for GGML.
  133. ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
  134. ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
  135. // mmvq kernel for GGML.
  136. ops.def(
  137. "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
  138. "-> Tensor");
  139. ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
  140. // mmq kernel for GGML.
  141. ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
  142. ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
  143. // Note about marlin kernel 'workspace' arguments:
  144. // Technically these should be mutable since they are modified by the kernel.
  145. // But since they are set back to zero once the kernel is finished we can
  146. // hand wave and say that they have no net effect.
  147. //
  148. // The reason to mark 'workspace' as immutable is so that they don't interfere
  149. // with using ScalarType arguments in the ops. If they are marked as mutable,
  150. // pytorch throws an assert in
  151. // 'torch._higher_order_ops._register_effectful_op' that prevents these
  152. // kernels from being torch.compile'd.
  153. // See the following document for more info on custom types and ops that use
  154. // custom types:
  155. // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
  156. // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  157. ops.def(
  158. "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
  159. "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
  160. ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
  161. // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  162. ops.def(
  163. "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
  164. "Tensor b_scales, Tensor workspace, "
  165. "__torch__.torch.classes._core_C.ScalarType b_q_type, "
  166. "int size_m, int size_n, int size_k) -> Tensor");
  167. ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
  168. // gptq_marlin Optimized Quantized GEMM for GPTQ.
  169. ops.def(
  170. "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
  171. "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
  172. "__torch__.torch.classes._core_C.ScalarType b_q_type, "
  173. "int size_m, int size_n, int size_k, bool is_k_full, "
  174. "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
  175. ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
  176. // gptq_marlin repack from GPTQ.
  177. ops.def(
  178. "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
  179. "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
  180. ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
  181. ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
  182. // awq_marlin repack from AWQ.
  183. ops.def(
  184. "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
  185. "SymInt size_n, int num_bits) -> Tensor");
  186. ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
  187. ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
  188. // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  189. ops.def(
  190. "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
  191. "Tensor! workspace, int num_bits, int size_m, int size_n, "
  192. "int size_k) -> Tensor");
  193. ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
  194. #ifndef _WIN32
  195. // marlin_qqq_gemm for QQQ.
  196. ops.def(
  197. "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
  198. "Tensor s_tok, Tensor s_ch, Tensor s_group, "
  199. "Tensor! workspace, int size_m, int size_n, "
  200. "int size_k) -> Tensor");
  201. ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
  202. // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  203. // quantization.
  204. ops.def(
  205. "cutlass_scaled_mm(Tensor! out, Tensor a,"
  206. " Tensor b, Tensor a_scales,"
  207. " Tensor b_scales, Tensor? bias) -> ()");
  208. ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
  209. // Check if cutlass scaled_mm is supported for CUDA devices of the given
  210. // capability
  211. ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  212. ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
  213. // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  214. // quantization.
  215. ops.def(
  216. "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
  217. " Tensor b, Tensor a_scales,"
  218. " Tensor b_scales, Tensor azp_adj,"
  219. " Tensor? azp, Tensor? bias) -> ()");
  220. ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
  221. // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
  222. ops.def("machete_supported_schedules", &machete::supported_schedules);
  223. ops.def(
  224. "machete_gemm(Tensor A, Tensor B,"
  225. " __torch__.torch.classes._core_C.ScalarType btype,"
  226. " Tensor? scales, Tensor? zeros, int? group_size,"
  227. " Tensor? C, float? alpha, float? beta, str? schedule)"
  228. "-> Tensor");
  229. ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
  230. ops.def(
  231. "machete_prepack_B(Tensor B,"
  232. " __torch__.torch.classes._core_C.ScalarType btype)"
  233. "-> Tensor");
  234. ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
  235. ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
  236. ops.impl("permute_cols", torch::kCUDA, &permute_cols);
  237. #endif
  238. // QuIP# GEMV
  239. ops.def("quip_gemv(Tensor A, Tensor B, Tensor CB) -> Tensor",
  240. &e8p_mm_origorder);
  241. ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
  242. // QuIP# Decompress
  243. ops.def("quip_decompress(Tensor YIs, Tensor CB, Tensor Y) -> ()",
  244. &decompress_e8p_origorder);
  245. ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
  246. // fp6_llm
  247. ops.def(
  248. "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA,"
  249. " Tensor _in_feats, Tensor _weights,"
  250. " Tensor _scales, int splitK=1) -> Tensor");
  251. ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
  252. &fp_eXmY_linear_forward_cuda);
  253. // Sampling Kernels
  254. ops.def(
  255. "sampling_from_probs(Tensor probs, Tensor uniform_samples, bool "
  256. "deterministic) -> Tensor",
  257. &sampling_from_probs);
  258. ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs);
  259. ops.def(
  260. "top_k_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  261. " Tensor? maybe_top_k_arr, int top_k_val,"
  262. " bool deterministic) -> Tensor[]",
  263. &top_k_sampling_from_probs);
  264. ops.impl("top_k_sampling_from_probs", torch::kCUDA,
  265. &top_k_sampling_from_probs);
  266. ops.def(
  267. "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  268. " Tensor? maybe_min_p_arr, float min_p_val,"
  269. " bool deterministic) -> Tensor[]",
  270. &min_p_sampling_from_probs);
  271. ops.impl("min_p_sampling_from_probs", torch::kCUDA,
  272. &min_p_sampling_from_probs);
  273. ops.def(
  274. "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  275. " Tensor? maybe_top_p_arr, float top_p_val,"
  276. " bool deterministic) -> Tensor[]",
  277. &top_p_sampling_from_probs);
  278. ops.impl("top_p_sampling_from_probs", torch::kCUDA,
  279. &top_p_sampling_from_probs);
  280. ops.def(
  281. "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  282. " Tensor? maybe_top_k_arr, float top_k_val,"
  283. " Tensor? maybe_top_p_arr, float top_p_val,"
  284. " bool deterministic) -> Tensor[]",
  285. &top_k_top_p_sampling_from_probs);
  286. ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA,
  287. &top_k_top_p_sampling_from_probs);
  288. ops.def(
  289. "top_k_renorm_prob(Tensor probs, Tensor? maybe_top_k_arr, int top_k_val) "
  290. "-> Tensor",
  291. &top_k_renorm_prob);
  292. ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
  293. ops.def(
  294. "top_p_renorm_prob(Tensor probs, Tensor? maybe_top_p_arr, float "
  295. "top_p_val) "
  296. "-> Tensor",
  297. &top_p_renorm_prob);
  298. ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
  299. ops.def(
  300. "top_k_mask_logits(Tensor logits, Tensor? maybe_top_k_arr, int "
  301. "top_k_val) -> Tensor",
  302. &top_k_mask_logits);
  303. ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
  304. #endif
  305. // Quantized GEMM for GPTQ.
  306. // Note: even though the C++ inferred schema is correct for this op, it seems
  307. // to prevent the meta function registry.
  308. ops.def(
  309. "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
  310. "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
  311. "-> Tensor");
  312. ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
  313. // Post processing for GPTQ.
  314. ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  315. ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
  316. // Quantized GEMM for SqueezeLLM.
  317. ops.def(
  318. "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
  319. "lookup_table) -> ()");
  320. ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
  321. // Compute FP8 quantized tensor for given scaling factor.
  322. ops.def(
  323. "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
  324. ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
  325. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
  326. ops.def(
  327. "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  328. "()");
  329. ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
  330. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
  331. ops.def(
  332. "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
  333. "Tensor! scale, Tensor? scale_ub) -> "
  334. "()");
  335. ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
  336. &dynamic_per_token_scaled_fp8_quant);
  337. // Aligning the number of tokens to be processed by each expert such
  338. // that it is divisible by the block size.
  339. ops.def(
  340. "moe_align_block_size(Tensor topk_ids, int num_experts,"
  341. " int block_size, Tensor! sorted_token_ids,"
  342. " Tensor! experts_ids,"
  343. " Tensor! num_tokens_post_pad) -> ()");
  344. ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  345. // Compute int8 quantized tensor for given scaling factor.
  346. /*
  347. Implementation:
  348. void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  349. input, torch::Tensor const& scale);
  350. */
  351. ops.def(
  352. "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
  353. "Tensor? azp) -> ()");
  354. ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
  355. // Compute int8 quantized tensor and scaling factor
  356. /*
  357. Implementation:
  358. void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  359. input, torch::Tensor& scales);
  360. */
  361. ops.def(
  362. "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
  363. "Tensor!? azp) -> ()");
  364. ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
  365. &dynamic_scaled_int8_quant);
  366. #ifndef USE_ROCM
  367. // Mamba kernels
  368. ops.def(
  369. "selective_scan_fwd(Tensor! u, Tensor! delta,"
  370. "Tensor! A, Tensor! B, Tensor! C,"
  371. "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
  372. "bool delta_softplus,"
  373. "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
  374. ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
  375. ops.def(
  376. "causal_conv1d_update(Tensor! x,"
  377. "Tensor! conv_state,"
  378. "Tensor! weight,"
  379. "Tensor? bias,"
  380. "bool silu_activation,"
  381. "Tensor? conv_state_indices) -> Tensor");
  382. ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
  383. ops.def(
  384. "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
  385. "Tensor? bias_,"
  386. "Tensor? seq_idx_,"
  387. "Tensor? initial_states_,"
  388. "Tensor? final_states_out_,"
  389. "bool silu_activation) -> Tensor");
  390. ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
  391. ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
  392. "float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
  393. "float softcap, bool return_softmax, Generator? gen) -> Tensor[]");
  394. ops.impl("fwd", torch::kCUDA, &mha_fwd);
  395. ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
  396. "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? block_table, Tensor? alibi_slopes, "
  397. "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
  398. "bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, "
  399. "Generator? gen) -> Tensor[]");
  400. ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
  401. ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, "
  402. "Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? block_table, Tensor? alibi_slopes, "
  403. "Tensor!? out, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
  404. "float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
  405. ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
  406. #endif
  407. }
  408. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  409. // Cache ops
  410. // Swap in (out) the cache blocks from src to dst.
  411. cache_ops.def(
  412. "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  413. cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
  414. // Copy the cache blocks from src to dst.
  415. cache_ops.def(
  416. "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
  417. "Tensor block_mapping) -> ()");
  418. cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
  419. // Reshape the key and value tensors and cache them.
  420. cache_ops.def(
  421. "reshape_and_cache(Tensor key, Tensor value,"
  422. " Tensor! key_cache, Tensor! value_cache,"
  423. " Tensor slot_mapping,"
  424. " str kv_cache_dtype,"
  425. " float k_scale, float v_scale) -> ()");
  426. cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
  427. // Reshape the key and value tensors and cache them.
  428. cache_ops.def(
  429. "reshape_and_cache_flash(Tensor key, Tensor value,"
  430. " Tensor! key_cache,"
  431. " Tensor! value_cache,"
  432. " Tensor slot_mapping,"
  433. " str kv_cache_dtype,"
  434. " float k_scale, float v_scale) -> ()");
  435. cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
  436. &reshape_and_cache_flash);
  437. // Convert the key and value cache to fp8 data type.
  438. cache_ops.def(
  439. "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
  440. "str kv_cache_dtype) -> ()");
  441. cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
  442. }
  443. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  444. // Cuda utils
  445. // Gets the specified device attribute.
  446. cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
  447. cuda_utils.impl("get_device_attribute", &get_device_attribute);
  448. // Gets the maximum shared memory per block device attribute.
  449. cuda_utils.def(
  450. "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
  451. cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
  452. &get_max_shared_memory_per_block_device_attribute);
  453. }
  454. #ifndef USE_ROCM
  455. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  456. // Custom all-reduce kernels
  457. custom_ar.def(
  458. "init_custom_ar(Tensor meta, Tensor rank_data, "
  459. "str[] handles, int[] offsets, int rank, "
  460. "bool full_nvlink) -> int");
  461. custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  462. custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  463. custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
  464. custom_ar.def(
  465. "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
  466. "()");
  467. custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
  468. custom_ar.def("dispose", &dispose);
  469. custom_ar.def("meta_size", &meta_size);
  470. custom_ar.def(
  471. "register_buffer(int fa, Tensor t, str[] handles, "
  472. "int[] offsets) -> ()");
  473. custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
  474. custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  475. custom_ar.def("register_graph_buffers", &register_graph_buffers);
  476. }
  477. #endif
  478. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)