123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- #include <cstdint>
- #include <torch/extension.h>
- void paged_attention_v1(
- torch::Tensor& out,
- torch::Tensor& query,
- torch::Tensor& key_cache,
- torch::Tensor& value_cache,
- int num_kv_heads,
- float scale,
- torch::Tensor& block_tables,
- torch::Tensor& context_lens,
- int block_size,
- int max_context_len,
- const c10::optional<torch::Tensor>& alibi_slopes,
- const bool enable_fp8_kv_cache);
- void paged_attention_v2(
- torch::Tensor& out,
- torch::Tensor& exp_sums,
- torch::Tensor& max_logits,
- torch::Tensor& tmp_out,
- torch::Tensor& query,
- torch::Tensor& key_cache,
- torch::Tensor& value_cache,
- int num_kv_heads,
- float scale,
- torch::Tensor& block_tables,
- torch::Tensor& context_lens,
- int block_size,
- int max_context_len,
- const c10::optional<torch::Tensor>& alibi_slopes,
- const bool enable_fp8_kv_cache);
- void rms_norm(
- torch::Tensor& out,
- torch::Tensor& input,
- torch::Tensor& weight,
- float epsilon);
- void fused_add_rms_norm(
- torch::Tensor& input,
- torch::Tensor& residual,
- torch::Tensor& weight,
- float epsilon);
- void rotary_embedding(
- torch::Tensor& positions,
- torch::Tensor& query,
- torch::Tensor& key,
- int head_size,
- torch::Tensor& cos_sin_cache,
- bool is_neox);
- void silu_and_mul(
- torch::Tensor& out,
- torch::Tensor& input);
- void gelu_new(
- torch::Tensor& out,
- torch::Tensor& input);
- void gelu_fast(
- torch::Tensor& out,
- torch::Tensor& input);
- // The AWQ kernels are only available on CUDA
- #ifndef USE_ROCM
- torch::Tensor awq_gemm(
- torch::Tensor _in_feats,
- torch::Tensor _kernel,
- torch::Tensor _scaling_factors,
- torch::Tensor _zeros,
- int split_k_iters);
- #endif
- void squeezellm_gemm(
- torch::Tensor vec,
- torch::Tensor mat,
- torch::Tensor mul,
- torch::Tensor lookup_table);
- torch::Tensor gptq_gemm(
- torch::Tensor a,
- torch::Tensor b_q_weight,
- torch::Tensor b_gptq_qzeros,
- torch::Tensor b_gptq_scales,
- torch::Tensor b_g_idx,
- bool use_exllama,
- int bit);
- void gptq_shuffle(
- torch::Tensor q_weight,
- torch::Tensor q_perm,
- int bit);
- void aphrodite_bincount(
- torch::Tensor src,
- torch::Tensor out);
-
|