ops.h

#pragma once

#include <torch/extension.h>
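
// PagedAttention, v1: single-pass attention over a block-paged KV cache.
// `block_tables` maps each sequence to its cache blocks, `context_lens` gives
// the per-sequence context length; optional ALiBi slopes and a quantized
// KV-cache dtype (with `kv_scale`) are supported.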
void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype,
  float kv_scale);
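
// PagedAttention, v2: two-pass variant for long contexts. Attention is first
// computed over partitions into `tmp_out`, with per-partition `exp_sums` and
// `max_logits`, then reduced into `out`.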
void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype,
  float kv_scale);
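
// RMSNorm: out = input / sqrt(mean(input^2) + epsilon) * weight.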
void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);
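
// Fused residual-add + RMSNorm, performed in place: `residual` is updated to
// input + residual, and `input` receives the normalized sum scaled by `weight`.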
void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);
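
// Applies rotary positional embeddings (RoPE) to `query` and `key` in place,
// using the precomputed `cos_sin_cache`; `is_neox` selects the GPT-NeoX
// (rotate-half) layout versus the GPT-J (interleaved) layout.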
void rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox);
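
// RoPE variant that additionally takes per-token offsets into the cos/sin
// cache (`cos_sin_cache_offsets`), e.g. when sequences in the same batch use
// different rotary scaling factors.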
void batched_rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox,
  int rot_dim,
  torch::Tensor& cos_sin_cache_offsets);
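
// Fused activation-and-multiply kernels for gated MLPs: each splits `input`
// in half along the last dimension and writes act(x1) * x2 to `out`, using
// SiLU, exact GELU, or tanh-approximated GELU. gelu_new and gelu_fast are
// plain elementwise GELU approximations.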
void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_tanh_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);
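
// Mixture-of-experts helper: sorts and pads the token-to-expert assignments
// in `topk_ids` so each expert's tokens form block_size-aligned runs, filling
// `sorted_token_ids`, `expert_ids`, and `num_tokens_post_pad` for the fused
// MoE kernel.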
void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor expert_ids,
  torch::Tensor num_tokens_post_pad);
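
// Custom all-reduce (CUDA only): a peer-to-peer all-reduce over NVLink using
// CUDA IPC handles, used as a faster alternative to NCCL for small tensors.
// `fptr_t` is an opaque handle to the context created by init_custom_ar;
// should_custom_ar decides whether the fast path applies, and the register_*
// functions share buffers (including CUDA-graph buffers) across ranks.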
#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif  // !USE_ROCM
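
// These declarations are implemented by the CUDA/HIP kernel sources and are
// typically exposed to Python via a pybind11 binding in a separate
// translation unit. Illustrative sketch only (binding file, submodule name,
// and docstrings assumed, not part of this header):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     pybind11::module ops = m.def_submodule("ops", "custom operators");
//     ops.def("paged_attention_v1", &paged_attention_v1,
//             "PagedAttention, version 1.");
//     ops.def("rms_norm", &rms_norm, "Apply RMSNorm.");
//     // ... remaining ops registered the same way.
//   }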