// ops.h

#include <cstdint>
#include <torch/extension.h>
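
// Paged attention, v1: scaled dot-product attention for a batch of queries
// against a block-structured (paged) KV cache, computed in a single kernel pass.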
void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);
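
// Paged attention, v2: splits long contexts into partitions and reduces the
// partial results (exp_sums, max_logits, tmp_out) into out.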
void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);
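
// Root-mean-square layer normalization: out = input / rms(input) * weight.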
void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);
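
// Fused residual-add + RMS norm; input and residual are updated in place.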
void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);
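
// Applies rotary positional embeddings to query and key in place using the
// precomputed cos_sin_cache; is_neox selects GPT-NeoX vs. GPT-J rotation style.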
void rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox);
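
// SwiGLU-style activation: splits input in half along the last dimension and
// computes SiLU(first half) * second half.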
void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);
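
// GELU activation using the tanh ("new GELU") approximation.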
void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);
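
// GELU activation using a faster tanh-based approximation.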
void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);
// The AWQ kernels are only available on CUDA
#ifndef USE_ROCM
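// AWQ quantized GEMM: multiplies _in_feats by the packed quantized weight
// _kernel, dequantizing with _scaling_factors and _zeros; split_k_iters
// controls split-K parallelization.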
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);
#endif
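
// SqueezeLLM quantized matrix-vector multiply: dequantizes mat through
// lookup_table and accumulates vec * mat into mul.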
void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);
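
// GPTQ quantized GEMM: multiplies a by the quantized weight b_q_weight using
// b_gptq_qzeros, b_gptq_scales, and b_g_idx; use_exllama selects the exllama
// kernel path, bit is the weight bit-width.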
torch::Tensor gptq_gemm(
  torch::Tensor a,
  torch::Tensor b_q_weight,
  torch::Tensor b_gptq_qzeros,
  torch::Tensor b_gptq_scales,
  torch::Tensor b_g_idx,
  bool use_exllama,
  int bit);
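
// Reorders GPTQ quantized weights (per q_perm) into the layout expected by
// the exllama GPTQ kernels.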
void gptq_shuffle(
  torch::Tensor q_weight,
  torch::Tensor q_perm,
  int bit);