// ops.h (1.9 KB)

  1. #pragma once
  2. #include <torch/extension.h>
  3. void paged_attention_v1(
  4. torch::Tensor& out,
  5. torch::Tensor& query,
  6. torch::Tensor& key_cache,
  7. torch::Tensor& value_cache,
  8. int num_kv_heads,
  9. float scale,
  10. torch::Tensor& block_tables,
  11. torch::Tensor& context_lens,
  12. int block_size,
  13. int max_context_len,
  14. const c10::optional<torch::Tensor>& alibi_slopes);
  15. void paged_attention_v2(
  16. torch::Tensor& out,
  17. torch::Tensor& exp_sums,
  18. torch::Tensor& max_logits,
  19. torch::Tensor& tmp_out,
  20. torch::Tensor& query,
  21. torch::Tensor& key_cache,
  22. torch::Tensor& value_cache,
  23. int num_kv_heads,
  24. float scale,
  25. torch::Tensor& block_tables,
  26. torch::Tensor& context_lens,
  27. int block_size,
  28. int max_context_len,
  29. const c10::optional<torch::Tensor>& alibi_slopes);
  30. void rms_norm(
  31. torch::Tensor& out,
  32. torch::Tensor& input,
  33. torch::Tensor& weight,
  34. float epsilon);
  35. void fused_add_rms_norm(
  36. torch::Tensor& input,
  37. torch::Tensor& residual,
  38. torch::Tensor& weight,
  39. float epsilon);
  40. void rotary_embedding(
  41. torch::Tensor& positions,
  42. torch::Tensor& query,
  43. torch::Tensor& key,
  44. int head_size,
  45. torch::Tensor& cos_sin_cache,
  46. bool is_neox);
  47. void silu_and_mul(
  48. torch::Tensor& out,
  49. torch::Tensor& input);
  50. void gelu_new(
  51. torch::Tensor& out,
  52. torch::Tensor& input);
  53. void gelu_fast(
  54. torch::Tensor& out,
  55. torch::Tensor& input);
  56. #ifndef USE_ROCM
  57. torch::Tensor awq_gemm(
  58. torch::Tensor _in_feats,
  59. torch::Tensor _kernel,
  60. torch::Tensor _scaling_factors,
  61. torch::Tensor _zeros,
  62. int split_k_iters);
  63. #endif
  64. void squeezellm_gemm(
  65. torch::Tensor vec,
  66. torch::Tensor mat,
  67. torch::Tensor mul,
  68. torch::Tensor lookup_table);
  69. torch::Tensor gptq_gemm(
  70. torch::Tensor a,
  71. torch::Tensor b_q_weight,
  72. torch::Tensor b_gptq_qzeros,
  73. torch::Tensor b_gptq_scales,
  74. torch::Tensor b_g_idx,
  75. bool use_exllama,
  76. int bit);
  77. void gptq_shuffle(
  78. torch::Tensor q_weight,
  79. torch::Tensor q_perm,
  80. int bit);