torch_bindings.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. #include "cache.h"
  2. #include "cuda_utils.h"
  3. #include "ops.h"
  4. #include "core/registration.h"
  5. #include "quantization/quant_ops.h"
  6. #include <torch/library.h>
  7. // Note on op signatures:
  8. // The X_meta signatures are for the meta functions corresponding to op X.
  9. // They must be kept in sync with the signature for X. Generally, only
  10. // functions that return Tensors require a meta function.
  11. //
  12. // See the following links for detailed docs on op registration and function
  13. // schemas.
  14. // https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
  15. // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
  16. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  17. // Aphrodite custom ops
  18. // Attention ops
  19. // Compute the attention between an input query and the cached
  20. // keys/values using PagedAttention.
  21. ops.def(
  22. "paged_attention_v1("
  23. " Tensor! out, Tensor query, Tensor key_cache,"
  24. " Tensor value_cache, int num_kv_heads, float scale,"
  25. " Tensor block_tables, Tensor seq_lens, int block_size,"
  26. " int max_seq_len, Tensor? alibi_slopes,"
  27. " str kv_cache_dtype, float k_scale, float v_scale,"
  28. " int tp_rank, int blocksparse_local_blocks,"
  29. " int blocksparse_vert_stride, int blocksparse_block_size,"
  30. " int blocksparse_head_sliding_step) -> ()");
  31. ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
  32. // PagedAttention V2.
  33. ops.def(
  34. "paged_attention_v2("
  35. " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
  36. " Tensor! tmp_out, Tensor query, Tensor key_cache,"
  37. " Tensor value_cache, int num_kv_heads, float scale,"
  38. " Tensor block_tables, Tensor seq_lens, int block_size,"
  39. " int max_seq_len, Tensor? alibi_slopes,"
  40. " str kv_cache_dtype, float k_scale, float v_scale,"
  41. " int tp_rank, int blocksparse_local_blocks,"
  42. " int blocksparse_vert_stride, int blocksparse_block_size,"
  43. " int blocksparse_head_sliding_step) -> ()");
  44. ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
  45. // Activation ops
  46. // Activation function used in SwiGLU.
  47. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  48. ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
  49. // Activation function used in GeGLU with `none` approximation.
  50. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  51. ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
  52. // Activation function used in GeGLU with `tanh` approximation.
  53. ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  54. ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
  55. // GELU implementation used in GPT-2.
  56. ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  57. ops.impl("gelu_new", torch::kCUDA, &gelu_new);
  58. // Approximate GELU implementation.
  59. ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  60. ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
  61. // Quick GELU implementation.
  62. ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  63. ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
  64. // prepare_inputs advance_step
  65. ops.def(
  66. "advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
  67. "Tensor! input_tokens, Tensor sampled_token_ids, "
  68. "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
  69. "Tensor block_tables) -> ()");
  70. ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
  71. ops.def(
  72. "advance_step_flashinfer("
  73. " int num_seqs, int num_queries, int block_size,"
  74. " Tensor! input_tokens, Tensor sampled_token_ids,"
  75. " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
  76. " Tensor block_tables, Tensor! paged_kv_indices,"
  77. " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
  78. " Tensor! block_table_bounds"
  79. ") -> ()");
  80. ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
  81. // Layernorm
  82. // Apply Root Mean Square (RMS) Normalization to the input tensor.
  83. ops.def(
  84. "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
  85. "()");
  86. ops.impl("rms_norm", torch::kCUDA, &rms_norm);
  87. // In-place fused Add and RMS Normalization.
  88. ops.def(
  89. "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
  90. "float epsilon) -> ()");
  91. ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
  92. // Rotary embedding
  93. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  94. ops.def(
  95. "rotary_embedding(Tensor positions, Tensor! query,"
  96. " Tensor! key, int head_size,"
  97. " Tensor cos_sin_cache, bool is_neox) -> ()");
  98. ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
  99. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  100. // (supports multiple loras).
  101. ops.def(
  102. "batched_rotary_embedding(Tensor positions, Tensor! query,"
  103. " Tensor! key, int head_size,"
  104. " Tensor cos_sin_cache, bool is_neox,"
  105. " int rot_dim,"
  106. " Tensor cos_sin_cache_offsets) -> ()");
  107. ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
  108. // Quantization ops
  109. #ifndef USE_ROCM
  110. // Quantized GEMM for AQLM.
  111. ops.def(
  112. "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
  113. "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
  114. "-> Tensor");
  115. ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
  116. // Decompression method for AQLM.
  117. ops.def(
  118. "aqlm_dequant(Tensor codes, Tensor codebooks, "
  119. "int[] codebook_partition_sizes) -> Tensor");
  120. ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
  121. // Quantized GEMM for AWQ.
  122. ops.def(
  123. "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
  124. "Tensor _zeros, int split_k_iters) -> Tensor");
  125. ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
  126. // Dequantization for AWQ.
  127. ops.def(
  128. "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
  129. "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
  130. ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
  131. // Dequantization for GGML.
  132. ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
  133. ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
  134. // mmvq kernel for GGML.
  135. ops.def(
  136. "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
  137. "-> Tensor");
  138. ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
  139. // mmq kernel for GGML.
  140. ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
  141. ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
  142. // Note about marlin kernel 'workspace' arguments:
  143. // Technically these should be mutable since they are modified by the kernel.
  144. // But since they are set back to zero once the kernel is finished we can
  145. // hand wave and say that they have no net effect.
  146. //
  147. // The reason to mark 'workspace' as immutable is so that they don't interfere
  148. // with using ScalarType arguments in the ops. If they are marked as mutable,
  149. // pytorch throws an assert in
  150. // 'torch._higher_order_ops._register_effectful_op' that prevents these
  151. // kernels from being torch.compile'd.
  152. // See the following document for more info on custom types and ops that use
  153. // custom types:
  154. // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
  155. // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  156. ops.def(
  157. "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
  158. "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
  159. ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
  160. // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  161. ops.def(
  162. "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
  163. "Tensor b_scales, Tensor workspace, "
  164. "__torch__.torch.classes._core_C.ScalarType b_q_type, "
  165. "int size_m, int size_n, int size_k) -> Tensor");
  166. ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
  167. // gptq_marlin Optimized Quantized GEMM for GPTQ.
  168. ops.def(
  169. "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
  170. "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
  171. "__torch__.torch.classes._core_C.ScalarType b_q_type, "
  172. "int size_m, int size_n, int size_k, bool is_k_full, "
  173. "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
  174. ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
  175. // gptq_marlin repack from GPTQ.
  176. ops.def(
  177. "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
  178. "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
  179. ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
  180. ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
  181. // awq_marlin repack from AWQ.
  182. ops.def(
  183. "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
  184. "SymInt size_n, int num_bits) -> Tensor");
  185. ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
  186. ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
  187. // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
  188. ops.def(
  189. "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
  190. "Tensor! workspace, int num_bits, int size_m, int size_n, "
  191. "int size_k) -> Tensor");
  192. ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
  193. #ifndef _WIN32
  194. // marlin_qqq_gemm for QQQ.
  195. ops.def(
  196. "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
  197. "Tensor s_tok, Tensor s_ch, Tensor s_group, "
  198. "Tensor! workspace, int size_m, int size_n, "
  199. "int size_k) -> Tensor");
  200. ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
  201. // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  202. // quantization.
  203. ops.def(
  204. "cutlass_scaled_mm(Tensor! out, Tensor a,"
  205. " Tensor b, Tensor a_scales,"
  206. " Tensor b_scales, Tensor? bias) -> ()");
  207. ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
  208. // Check if cutlass scaled_mm is supported for CUDA devices of the given
  209. // capability
  210. ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  211. ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
  212. // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  213. // quantization.
  214. ops.def(
  215. "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
  216. " Tensor b, Tensor a_scales,"
  217. " Tensor b_scales, Tensor azp_adj,"
  218. " Tensor? azp, Tensor? bias) -> ()");
  219. ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
  220. // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
  221. ops.def("machete_supported_schedules", &machete::supported_schedules);
  222. ops.def(
  223. "machete_gemm(Tensor A, Tensor B,"
  224. " __torch__.torch.classes._core_C.ScalarType btype,"
  225. " Tensor? scales, Tensor? zeros, int? group_size,"
  226. " Tensor? C, float? alpha, float? beta, str? schedule)"
  227. "-> Tensor");
  228. ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
  229. ops.def(
  230. "machete_prepack_B(Tensor B,"
  231. " __torch__.torch.classes._core_C.ScalarType btype)"
  232. "-> Tensor");
  233. ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
  234. ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
  235. ops.impl("permute_cols", torch::kCUDA, &permute_cols);
  236. #endif
  237. // QuIP# GEMV
  238. ops.def("quip_gemv(Tensor A, Tensor B, Tensor CB) -> Tensor",
  239. &e8p_mm_origorder);
  240. ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
  241. // QuIP# Decompress
  242. ops.def("quip_decompress(Tensor YIs, Tensor CB, Tensor Y) -> ()",
  243. &decompress_e8p_origorder);
  244. ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
  245. // fp6_llm
  246. ops.def(
  247. "fp_eXmY_linear_forward_cuda(int EXPONENT, int MANTISSA,"
  248. " Tensor _in_feats, Tensor _weights,"
  249. " Tensor _scales, int splitK=1) -> Tensor");
  250. ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
  251. &fp_eXmY_linear_forward_cuda);
  252. // Sampling Kernels
  253. ops.def(
  254. "sampling_from_probs(Tensor probs, Tensor uniform_samples, bool "
  255. "deterministic) -> Tensor",
  256. &sampling_from_probs);
  257. ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs);
  258. ops.def(
  259. "top_k_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  260. " Tensor? maybe_top_k_arr, int top_k_val,"
  261. " bool deterministic) -> Tensor[]",
  262. &top_k_sampling_from_probs);
  263. ops.impl("top_k_sampling_from_probs", torch::kCUDA,
  264. &top_k_sampling_from_probs);
  265. ops.def(
  266. "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  267. " Tensor? maybe_min_p_arr, float min_p_val,"
  268. " bool deterministic) -> Tensor[]",
  269. &min_p_sampling_from_probs);
  270. ops.impl("min_p_sampling_from_probs", torch::kCUDA,
  271. &min_p_sampling_from_probs);
  272. ops.def(
  273. "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  274. " Tensor? maybe_top_p_arr, float top_p_val,"
  275. " bool deterministic) -> Tensor[]",
  276. &top_p_sampling_from_probs);
  277. ops.impl("top_p_sampling_from_probs", torch::kCUDA,
  278. &top_p_sampling_from_probs);
  279. ops.def(
  280. "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
  281. " Tensor? maybe_top_k_arr, float top_k_val,"
  282. " Tensor? maybe_top_p_arr, float top_p_val,"
  283. " bool deterministic) -> Tensor[]",
  284. &top_k_top_p_sampling_from_probs);
  285. ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA,
  286. &top_k_top_p_sampling_from_probs);
  287. ops.def(
  288. "top_k_renorm_prob(Tensor probs, Tensor? maybe_top_k_arr, int top_k_val) "
  289. "-> Tensor",
  290. &top_k_renorm_prob);
  291. ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
  292. ops.def(
  293. "top_p_renorm_prob(Tensor probs, Tensor? maybe_top_p_arr, float "
  294. "top_p_val) "
  295. "-> Tensor",
  296. &top_p_renorm_prob);
  297. ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
  298. ops.def(
  299. "top_k_mask_logits(Tensor logits, Tensor? maybe_top_k_arr, int "
  300. "top_k_val) -> Tensor",
  301. &top_k_mask_logits);
  302. ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
  303. #endif
  304. // Quantized GEMM for GPTQ.
  305. // Note: even though the C++ inferred schema is correct for this op, it seems
  306. // to prevent the meta function registry.
  307. ops.def(
  308. "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
  309. "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
  310. "-> Tensor");
  311. ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
  312. // Post processing for GPTQ.
  313. ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  314. ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
  315. // Quantized GEMM for SqueezeLLM.
  316. ops.def(
  317. "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
  318. "lookup_table) -> ()");
  319. ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
  320. // Compute FP8 quantized tensor for given scaling factor.
  321. ops.def(
  322. "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
  323. ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
  324. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
  325. ops.def(
  326. "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  327. "()");
  328. ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
  329. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
  330. ops.def(
  331. "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
  332. "Tensor! scale, Tensor? scale_ub) -> "
  333. "()");
  334. ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
  335. &dynamic_per_token_scaled_fp8_quant);
  336. // Aligning the number of tokens to be processed by each expert such
  337. // that it is divisible by the block size.
  338. ops.def(
  339. "moe_align_block_size(Tensor topk_ids, int num_experts,"
  340. " int block_size, Tensor! sorted_token_ids,"
  341. " Tensor! experts_ids,"
  342. " Tensor! num_tokens_post_pad) -> ()");
  343. ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  344. // Compute int8 quantized tensor for given scaling factor.
  345. /*
  346. Implementation:
  347. void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  348. input, torch::Tensor const& scale);
  349. */
  350. ops.def(
  351. "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
  352. "Tensor? azp) -> ()");
  353. ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
  354. // Compute int8 quantized tensor and scaling factor
  355. /*
  356. Implementation:
  357. void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const&
  358. input, torch::Tensor& scales);
  359. */
  360. ops.def(
  361. "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
  362. "Tensor!? azp) -> ()");
  363. ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
  364. &dynamic_scaled_int8_quant);
  365. #ifndef USE_ROCM
  366. // Mamba kernels
  367. ops.def(
  368. "selective_scan_fwd(Tensor! u, Tensor! delta,"
  369. "Tensor! A, Tensor! B, Tensor! C,"
  370. "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
  371. "bool delta_softplus,"
  372. "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
  373. ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
  374. ops.def(
  375. "causal_conv1d_update(Tensor! x,"
  376. "Tensor! conv_state,"
  377. "Tensor! weight,"
  378. "Tensor? bias,"
  379. "bool silu_activation,"
  380. "Tensor? conv_state_indices) -> Tensor");
  381. ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
  382. ops.def(
  383. "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
  384. "Tensor? bias_,"
  385. "Tensor? seq_idx_,"
  386. "Tensor? initial_states_,"
  387. "Tensor? final_states_out_,"
  388. "bool silu_activation) -> Tensor");
  389. ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
  390. #endif
  391. }
  392. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  393. // Cache ops
  394. // Swap in (out) the cache blocks from src to dst.
  395. cache_ops.def(
  396. "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  397. cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
  398. // Copy the cache blocks from src to dst.
  399. cache_ops.def(
  400. "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
  401. "Tensor block_mapping) -> ()");
  402. cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
  403. // Reshape the key and value tensors and cache them.
  404. cache_ops.def(
  405. "reshape_and_cache(Tensor key, Tensor value,"
  406. " Tensor! key_cache, Tensor! value_cache,"
  407. " Tensor slot_mapping,"
  408. " str kv_cache_dtype,"
  409. " float k_scale, float v_scale) -> ()");
  410. cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
  411. // Reshape the key and value tensors and cache them.
  412. cache_ops.def(
  413. "reshape_and_cache_flash(Tensor key, Tensor value,"
  414. " Tensor! key_cache,"
  415. " Tensor! value_cache,"
  416. " Tensor slot_mapping,"
  417. " str kv_cache_dtype,"
  418. " float k_scale, float v_scale) -> ()");
  419. cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
  420. &reshape_and_cache_flash);
  421. // Convert the key and value cache to fp8 data type.
  422. cache_ops.def(
  423. "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
  424. "str kv_cache_dtype) -> ()");
  425. cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
  426. }
  427. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  428. // Cuda utils
  429. // Gets the specified device attribute.
  430. cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
  431. cuda_utils.impl("get_device_attribute", &get_device_attribute);
  432. // Gets the maximum shared memory per block device attribute.
  433. cuda_utils.def(
  434. "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
  435. cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
  436. &get_max_shared_memory_per_block_device_attribute);
  437. }
  438. #ifndef USE_ROCM
  439. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  440. // Custom all-reduce kernels
  441. custom_ar.def(
  442. "init_custom_ar(Tensor meta, Tensor rank_data, "
  443. "str[] handles, int[] offsets, int rank, "
  444. "bool full_nvlink) -> int");
  445. custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  446. custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  447. custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
  448. custom_ar.def(
  449. "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
  450. "()");
  451. custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
  452. custom_ar.def("dispose", &dispose);
  453. custom_ar.def("meta_size", &meta_size);
  454. custom_ar.def(
  455. "register_buffer(int fa, Tensor t, str[] handles, "
  456. "int[] offsets) -> ()");
  457. custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
  458. custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  459. custom_ar.def("register_graph_buffers", &register_graph_buffers);
  460. }
  461. #endif
  462. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)