// ops.h — operator declarations for the torch extension (attention, norms,
// rotary embedding, activations, and quantized GEMM kernels).
#pragma once

#include <cstdint>

#include <torch/extension.h>
  3. void paged_attention_v1(
  4. torch::Tensor& out,
  5. torch::Tensor& query,
  6. torch::Tensor& key_cache,
  7. torch::Tensor& value_cache,
  8. int num_kv_heads,
  9. float scale,
  10. torch::Tensor& block_tables,
  11. torch::Tensor& context_lens,
  12. int block_size,
  13. int max_context_len,
  14. const c10::optional<torch::Tensor>& alibi_slopes,
  15. const bool enable_fp8_kv_cache);
  16. void paged_attention_v2(
  17. torch::Tensor& out,
  18. torch::Tensor& exp_sums,
  19. torch::Tensor& max_logits,
  20. torch::Tensor& tmp_out,
  21. torch::Tensor& query,
  22. torch::Tensor& key_cache,
  23. torch::Tensor& value_cache,
  24. int num_kv_heads,
  25. float scale,
  26. torch::Tensor& block_tables,
  27. torch::Tensor& context_lens,
  28. int block_size,
  29. int max_context_len,
  30. const c10::optional<torch::Tensor>& alibi_slopes,
  31. const bool enable_fp8_kv_cache);
  32. void rms_norm(
  33. torch::Tensor& out,
  34. torch::Tensor& input,
  35. torch::Tensor& weight,
  36. float epsilon);
  37. void fused_add_rms_norm(
  38. torch::Tensor& input,
  39. torch::Tensor& residual,
  40. torch::Tensor& weight,
  41. float epsilon);
  42. void rotary_embedding(
  43. torch::Tensor& positions,
  44. torch::Tensor& query,
  45. torch::Tensor& key,
  46. int head_size,
  47. torch::Tensor& cos_sin_cache,
  48. bool is_neox);
  49. void silu_and_mul(
  50. torch::Tensor& out,
  51. torch::Tensor& input);
  52. void gelu_new(
  53. torch::Tensor& out,
  54. torch::Tensor& input);
  55. void gelu_fast(
  56. torch::Tensor& out,
  57. torch::Tensor& input);
  58. // The AWQ kernels are only available on CUDA
  59. #ifndef USE_ROCM
  60. torch::Tensor awq_gemm(
  61. torch::Tensor _in_feats,
  62. torch::Tensor _kernel,
  63. torch::Tensor _scaling_factors,
  64. torch::Tensor _zeros,
  65. int split_k_iters);
  66. #endif
  67. void squeezellm_gemm(
  68. torch::Tensor vec,
  69. torch::Tensor mat,
  70. torch::Tensor mul,
  71. torch::Tensor lookup_table);
  72. torch::Tensor gptq_gemm(
  73. torch::Tensor a,
  74. torch::Tensor b_q_weight,
  75. torch::Tensor b_gptq_qzeros,
  76. torch::Tensor b_gptq_scales,
  77. torch::Tensor b_g_idx,
  78. bool use_exllama,
  79. int bit);
  80. void gptq_shuffle(
  81. torch::Tensor q_weight,
  82. torch::Tensor q_perm,
  83. int bit);
  84. void aphrodite_bincount(
  85. torch::Tensor src,
  86. torch::Tensor out);