1
0

ops.h 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  #pragma once

  #include <cstdint>
  #include <optional>
  #include <string>
  #include <tuple>
  #include <vector>

  #include <torch/library.h>
  4. void paged_attention_v1(
  5. torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
  6. torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
  7. torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
  8. int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
  9. const std::string& kv_cache_dtype, double k_scale, double v_scale,
  10. const int64_t tp_rank, const int64_t blocksparse_local_blocks,
  11. const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  12. const int64_t blocksparse_head_sliding_step);
  13. void paged_attention_v2(
  14. torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
  15. torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
  16. torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
  17. torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
  18. int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
  19. const std::string& kv_cache_dtype, double k_scale, double v_scale,
  20. const int64_t tp_rank, const int64_t blocksparse_local_blocks,
  21. const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  22. const int64_t blocksparse_head_sliding_step);
  23. void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
  24. double epsilon);
  25. void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
  26. torch::Tensor& weight, double epsilon);
  27. void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
  28. torch::Tensor& key, int64_t head_size,
  29. torch::Tensor& cos_sin_cache, bool is_neox);
  30. void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
  31. torch::Tensor& key, int64_t head_size,
  32. torch::Tensor& cos_sin_cache, bool is_neox,
  33. int64_t rot_dim,
  34. torch::Tensor& cos_sin_cache_offsets);
  35. void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
  36. void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
  37. void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
  38. void gelu_new(torch::Tensor& out, torch::Tensor& input);
  39. void gelu_fast(torch::Tensor& out, torch::Tensor& input);
  40. void gelu_quick(torch::Tensor& out, torch::Tensor& input);
  41. void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
  42. torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
  43. torch::Tensor& input_positions, torch::Tensor& seq_lens,
  44. torch::Tensor& slot_mapping, torch::Tensor& block_tables);
  45. void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  46. int64_t block_size, torch::Tensor sorted_token_ids,
  47. torch::Tensor expert_ids,
  48. torch::Tensor num_tokens_post_pad);
  49. std::vector<torch::Tensor> selective_scan_fwd(
  50. const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
  51. const torch::Tensor& B, const torch::Tensor& C,
  52. const c10::optional<torch::Tensor>& D_,
  53. const c10::optional<torch::Tensor>& z_,
  54. const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
  55. const c10::optional<torch::Tensor>& index_,
  56. const c10::optional<torch::Tensor>& x);
  57. at::Tensor causal_conv1d_update(const at::Tensor& x,
  58. const at::Tensor& conv_state,
  59. const at::Tensor& weight,
  60. const c10::optional<at::Tensor>& bias_,
  61. bool silu_activation);
  62. at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
  63. const c10::optional<at::Tensor>& bias_,
  64. const c10::optional<at::Tensor>& seq_idx_,
  65. const c10::optional<at::Tensor>& seq_pos_idx_,
  66. const c10::optional<at::Tensor>& initial_states_,
  67. const c10::optional<at::Tensor>& final_states_out_,
  68. bool silu_activation);
  69. #ifndef USE_ROCM
  70. using fptr_t = int64_t;
  71. fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
  72. const std::vector<std::string>& handles,
  73. const std::vector<int64_t>& offsets, int64_t rank,
  74. bool full_nvlink);
  75. bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
  76. bool full_nvlink);
  77. void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
  78. void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
  79. torch::Tensor& out);
  80. void dispose(fptr_t _fa);
  81. int64_t meta_size();
  82. void register_buffer(fptr_t _fa, torch::Tensor& t,
  83. const std::vector<std::string>& handles,
  84. const std::vector<int64_t>& offsets);
  85. std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
  86. fptr_t _fa);
  87. void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
  88. const std::vector<std::vector<int64_t>>& offsets);
  89. #endif