@@ -37,8 +37,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
-      " Tensor! out, Tensor exp_sums, Tensor max_logits,"
-      " Tensor tmp_out, Tensor query, Tensor key_cache,"
+      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
       " Tensor value_cache, int num_kv_heads, float scale,"
       " Tensor block_tables, Tensor seq_lens, int block_size,"
       " int max_seq_len, Tensor? alibi_slopes,"
@@ -74,7 +74,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
 
   // prepare_inputs advance_step
-  ops.def("advance_step", &advance_step);
+  ops.def(
+      "advance_step(int num_seqs, int num_queries, int block_size, "
+      "Tensor! input_tokens, Tensor sampled_token_ids, "
+      "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
+      "Tensor block_tables) -> ()");
   ops.impl("advance_step", torch::kCUDA, &advance_step);
 
   // Layernorm
@@ -111,60 +115,108 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantization ops
 #ifndef USE_ROCM
   // Quantized GEMM for AQLM.
-  ops.def("aqlm_gemm", &aqlm_gemm);
+  ops.def(
+      "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
+      "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
+      "-> Tensor");
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
 
   // Decompression method for AQLM.
-  ops.def("aqlm_dequant", &aqlm_dequant);
+  ops.def(
+      "aqlm_dequant(Tensor codes, Tensor codebooks, "
+      "int[] codebook_partition_sizes) -> Tensor");
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
 
   // Quantized GEMM for AWQ.
-  ops.def("awq_gemm", &awq_gemm);
+  ops.def(
+      "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
-  ops.def("awq_dequantize", &awq_dequantize);
+  ops.def(
+      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Dequantization for GGML.
-  ops.def("ggml_dequantize", &ggml_dequantize);
+  ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
   ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
 
   // mmvq kernel for GGML.
-  ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
+  ops.def(
+      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
+      "-> Tensor");
   ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
 
   // mmq kernel for GGML.
-  ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
+  ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
   ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
 
+  // Note about marlin kernel 'workspace' arguments:
+  // Technically these should be mutable since they are modified by the kernel.
+  // But since they are set back to zero once the kernel is finished we can
+  // hand wave and say that they have no net effect.
+  //
+  // The reason to mark 'workspace' as immutable is so that they don't interfere
+  // with using ScalarType arguments in the ops. If they are marked as mutable,
+  // pytorch throws an assert in
+  // 'torch._higher_order_ops._register_effectful_op' that prevents these
+  // kernels from being torch.compile'd.
+  // See the following document for more info on custom types and ops that use
+  // custom types:
+  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
   // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
-  ops.def("marlin_gemm", &marlin_gemm);
+  ops.def(
+      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.def(
+      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+      "Tensor b_scales, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
 
   // gptq_marlin Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.def(
+      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k, bool is_k_full, "
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
   ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
 
   // gptq_marlin repack from GPTQ.
-  ops.def("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.def(
+      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
   ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
+  ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
 
   // awq_marlin repack from AWQ.
-  ops.def("awq_marlin_repack", &awq_marlin_repack);
+  ops.def(
+      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+      "SymInt size_n, int num_bits) -> Tensor");
   ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
+  ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
 
   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
-  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.def(
+      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int num_bits, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
 
 #ifndef _WIN32
   // marlin_qqq_gemm for QQQ.
-  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
+  ops.def(
+      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+      "Tensor! workspace, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -177,9 +229,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
   // capability
-  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
-  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
-           &cutlass_scaled_mm_supports_fp8);
+  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
+  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
   // quantization.
@@ -211,11 +262,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #endif
 
   // QuIP# GEMV
-  ops.def("quip_gemv", &e8p_mm_origorder);
+  ops.def("quip_gemv(Tensor A, Tensor B, Tensor CB) -> Tensor",
+          &e8p_mm_origorder);
   ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
 
   // QuIP# Decompress
-  ops.def("quip_decompress", &decompress_e8p_origorder);
+  ops.def("quip_decompress(Tensor YIs, Tensor CB, Tensor Y) -> ()",
+          &decompress_e8p_origorder);
   ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
 
   // fp6_llm
@@ -227,31 +280,73 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
            &fp_eXmY_linear_forward_cuda);
 
   // Sampling Kernels
-  ops.def("sampling_from_probs", &sampling_from_probs);
+  ops.def(
+      "sampling_from_probs(Tensor probs, Tensor uniform_samples, bool "
+      "deterministic) -> Tensor",
+      &sampling_from_probs);
   ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs);
-  ops.def("top_k_sampling_from_probs", &top_k_sampling_from_probs);
+
+  ops.def(
+      "top_k_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      " Tensor? maybe_top_k_arr, int top_k_val,"
+      " bool deterministic) -> Tensor[]",
+      &top_k_sampling_from_probs);
   ops.impl("top_k_sampling_from_probs", torch::kCUDA,
            &top_k_sampling_from_probs);
-  ops.def("min_p_sampling_from_probs", &min_p_sampling_from_probs);
+
+  ops.def(
+      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      " Tensor? maybe_min_p_arr, float min_p_val,"
+      " bool deterministic) -> Tensor[]",
+      &min_p_sampling_from_probs);
   ops.impl("min_p_sampling_from_probs", torch::kCUDA,
            &min_p_sampling_from_probs);
-  ops.def("top_p_sampling_from_probs", &top_p_sampling_from_probs);
+
+  ops.def(
+      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      " Tensor? maybe_top_p_arr, float top_p_val,"
+      " bool deterministic) -> Tensor[]",
+      &top_p_sampling_from_probs);
   ops.impl("top_p_sampling_from_probs", torch::kCUDA,
            &top_p_sampling_from_probs);
-  ops.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs);
+
+  ops.def(
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      " Tensor? maybe_top_k_arr, float top_k_val,"
+      " Tensor? maybe_top_p_arr, float top_p_val,"
+      " bool deterministic) -> Tensor[]",
+      &top_k_top_p_sampling_from_probs);
   ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA,
            &top_k_top_p_sampling_from_probs);
- ops.def("top_k_renorm_prob", &top_k_renorm_prob);
|
|
|
|
|
|
+
|
|
|
|
+ ops.def(
|
|
|
|
+ "top_k_renorm_prob(Tensor probs, Tensor? maybe_top_k_arr, int top_k_val) "
|
|
|
|
+ "-> Tensor",
|
|
|
|
+ &top_k_renorm_prob);
|
|
ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
|
|
ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
|
|
- ops.def("top_p_renorm_prob", &top_p_renorm_prob);
|
|
|
|
|
|
+
|
|
|
|
+ ops.def(
|
|
|
|
+ "top_p_renorm_prob(Tensor probs, Tensor? maybe_top_p_arr, float "
|
|
|
|
+ "top_p_val) "
|
|
|
|
+ "-> Tensor",
|
|
|
|
+ &top_p_renorm_prob);
|
|
ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
|
|
ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
|
|
- ops.def("top_k_mask_logits", &top_k_mask_logits);
|
|
|
|
|
|
+
|
|
|
|
+ ops.def(
|
|
|
|
+ "top_k_mask_logits(Tensor logits, Tensor? maybe_top_k_arr, int "
|
|
|
|
+ "top_k_val) -> Tensor",
|
|
|
|
+ &top_k_mask_logits);
|
|
ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
|
ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
|
|
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
// Quantized GEMM for GPTQ.
|
|
// Quantized GEMM for GPTQ.
|
|
- ops.def("gptq_gemm", &gptq_gemm);
|
|
|
|
|
|
+ // Note: even though the C++ inferred schema is correct for this op, it seems
|
|
|
|
+ // to prevent the meta function registry.
|
|
|
|
+ ops.def(
|
|
|
|
+ "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
|
|
|
|
+ "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
|
|
|
|
+ "-> Tensor");
|
|
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
|
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
|
|
|
|
|
// Post processing for GPTQ.
|
|
// Post processing for GPTQ.
|
|
@@ -277,8 +372,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
   ops.def(
-      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
-      "scale, Tensor? scale_ub) -> "
+      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
+      "Tensor! scale, Tensor? scale_ub) -> "
       "()");
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
@@ -321,7 +416,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor! A, Tensor! B, Tensor! C,"
       "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
       "bool delta_softplus,"
-      "Tensor? index_, Tensor? x) -> Tensor[]");
+      "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
   ops.def(
@@ -353,8 +448,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
   // Copy the cache blocks from src to dst.
   cache_ops.def(
-      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
-      "block_mapping) -> ()");
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 
   // Reshape the key and value tensors and cache them.
@@ -379,8 +474,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
-      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
-      "kv_cache_dtype) -> ()");
+      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
+      "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 }
 
@@ -388,24 +483,27 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
   // Cuda utils
 
   // Gets the specified device attribute.
-  cuda_utils.def("get_device_attribute", &get_device_attribute);
-  cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
+  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
+  cuda_utils.impl("get_device_attribute", &get_device_attribute);
 
   // Gets the maximum shared memory per block device attribute.
-  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
-                 &get_max_shared_memory_per_block_device_attribute);
+  cuda_utils.def(
+      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
   cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
-                  torch::kCUDA,
                   &get_max_shared_memory_per_block_device_attribute);
 }
 
 #ifndef USE_ROCM
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   // Custom all-reduce kernels
-  custom_ar.def("init_custom_ar", &init_custom_ar);
+  custom_ar.def(
+      "init_custom_ar(Tensor meta, Tensor rank_data, "
+      "str[] handles, int[] offsets, int rank, "
+      "bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-
-  custom_ar.def("should_custom_ar", &should_custom_ar);
+  custom_ar.def(
+      "should_custom_ar(Tensor inp, int max_size, int world_size, "
+      "bool full_nvlink) -> bool");
   custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
 
   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
@@ -417,21 +515,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
 
   custom_ar.def("dispose", &dispose);
-  custom_ar.impl("dispose", torch::kCPU, &dispose);
 
   custom_ar.def("meta_size", &meta_size);
-  custom_ar.impl("meta_size", torch::kCPU, &meta_size);
 
-  custom_ar.def("register_buffer", &register_buffer);
+  custom_ar.def(
+      "register_buffer(int fa, Tensor t, str[] handles, "
+      "int[] offsets) -> ()");
   custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
 
   custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
-  custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
-                 &get_graph_buffer_ipc_meta);
 
   custom_ar.def("register_graph_buffers", &register_graph_buffers);
-  custom_ar.impl("register_graph_buffers", torch::kCPU,
-                 &register_graph_buffers);
 }
 #endif
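
A minimal, standalone sketch of the registration style this patch moves to, for readers unfamiliar with schema-string def() calls. It is not part of the patch; the library name "sketch_ops" and the op "scale_inplace" are invented for illustration. The schema spells out argument types, mutability ("Tensor!"), and the return type instead of letting PyTorch infer them from the C++ signature, which is what makes annotations like the ones in the marlin 'workspace' note possible.

#include <torch/torch.h>
#include <torch/library.h>

// Scales `x` in place; the "Tensor!" annotation in the schema below marks
// this argument as mutated.
void scale_inplace(at::Tensor& x, double factor) { x.mul_(factor); }

TORCH_LIBRARY(sketch_ops, m) {
  // Explicit schema string: types, mutability, and return type are stated
  // up front rather than inferred from the C++ function signature.
  m.def("scale_inplace(Tensor! x, float factor) -> ()");
  // Bind a CPU implementation for the op under this dispatch key.
  m.impl("scale_inplace", torch::kCPU, &scale_inplace);
}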