// ops.h: declarations of custom tensor operators (definitions live in other
// translation units).

  #pragma once

  #include <cstdint>
  #include <optional>
  #include <string>
  #include <tuple>
  #include <vector>

  #include <torch/library.h>
  4. void paged_attention_v1(
  5. torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
  6. torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
  7. torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
  8. int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
  9. const std::string& kv_cache_dtype, double k_scale, double v_scale,
  10. const int64_t tp_rank, const int64_t blocksparse_local_blocks,
  11. const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  12. const int64_t blocksparse_head_sliding_step);
  13. void paged_attention_v2(
  14. torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
  15. torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
  16. torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
  17. torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
  18. int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
  19. const std::string& kv_cache_dtype, double k_scale, double v_scale,
  20. const int64_t tp_rank, const int64_t blocksparse_local_blocks,
  21. const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
  22. const int64_t blocksparse_head_sliding_step);
  23. void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
  24. double epsilon);
  25. void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
  26. torch::Tensor& weight, double epsilon);
  27. void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
  28. torch::Tensor& key, int64_t head_size,
  29. torch::Tensor& cos_sin_cache, bool is_neox);
  30. void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
  31. torch::Tensor& key, int64_t head_size,
  32. torch::Tensor& cos_sin_cache, bool is_neox,
  33. int64_t rot_dim,
  34. torch::Tensor& cos_sin_cache_offsets);
  35. void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
  36. void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
  37. void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
  38. void gelu_new(torch::Tensor& out, torch::Tensor& input);
  39. void gelu_fast(torch::Tensor& out, torch::Tensor& input);
  40. void gelu_quick(torch::Tensor& out, torch::Tensor& input);
  41. void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
  42. int64_t block_size, torch::Tensor& input_tokens,
  43. torch::Tensor& sampled_token_ids,
  44. torch::Tensor& input_positions,
  45. torch::Tensor& seq_lens,
  46. torch::Tensor& slot_mapping,
  47. torch::Tensor& block_tables);
  48. void advance_step_flashinfer(
  49. int64_t num_seqs, int64_t num_queries, int64_t block_size,
  50. torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
  51. torch::Tensor& input_positions, torch::Tensor& seq_lens,
  52. torch::Tensor& slot_mapping, torch::Tensor& block_tables,
  53. torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
  54. torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
  55. void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  56. int64_t block_size, torch::Tensor sorted_token_ids,
  57. torch::Tensor expert_ids,
  58. torch::Tensor num_tokens_post_pad);
  59. #ifndef USE_ROCM
  60. using fptr_t = int64_t;
  61. fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
  62. const std::vector<std::string>& handles,
  63. const std::vector<int64_t>& offsets, int64_t rank,
  64. bool full_nvlink);
  65. void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
  66. void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
  67. torch::Tensor& out);
  68. void dispose(fptr_t _fa);
  69. int64_t meta_size();
  70. void register_buffer(fptr_t _fa, torch::Tensor& t,
  71. const std::vector<std::string>& handles,
  72. const std::vector<int64_t>& offsets);
  73. std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
  74. fptr_t _fa);
  75. void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
  76. const std::vector<std::vector<int64_t>>& offsets);
  77. std::vector<torch::Tensor> selective_scan_fwd(
  78. const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
  79. const torch::Tensor& B, const torch::Tensor& C,
  80. const c10::optional<torch::Tensor>& D_,
  81. const c10::optional<torch::Tensor>& z_,
  82. const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
  83. const c10::optional<torch::Tensor>& index_,
  84. const c10::optional<torch::Tensor>& x);
  85. at::Tensor causal_conv1d_update(
  86. const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
  87. const c10::optional<at::Tensor>& bias, bool silu_activation,
  88. const c10::optional<at::Tensor>& conv_state_indices);
  89. at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
  90. const c10::optional<at::Tensor>& bias_,
  91. const c10::optional<at::Tensor>& seq_idx_,
  92. const c10::optional<at::Tensor>& initial_states_,
  93. const c10::optional<at::Tensor>& final_states_out_,
  94. bool silu_activation);
  95. torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
  96. // Sampling kernels
  97. torch::Tensor sampling_from_probs(torch::Tensor probs,
  98. torch::Tensor uniform_samples,
  99. bool deterministic);
  100. std::vector<torch::Tensor> top_p_sampling_from_probs(
  101. torch::Tensor probs, torch::Tensor uniform_samples,
  102. std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
  103. bool deterministic);
  104. std::vector<torch::Tensor> top_k_sampling_from_probs(
  105. torch::Tensor probs, torch::Tensor uniform_samples,
  106. std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
  107. bool deterministic);
  108. std::vector<torch::Tensor> min_p_sampling_from_probs(
  109. torch::Tensor probs, torch::Tensor uniform_samples,
  110. std::optional<torch::Tensor> maybe_min_p_arr, double min_p_val,
  111. bool deterministic);
  112. std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
  113. torch::Tensor probs, torch::Tensor uniform_samples,
  114. std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
  115. std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
  116. bool deterministic);
  117. torch::Tensor top_p_renorm_prob(torch::Tensor probs,
  118. std::optional<torch::Tensor> maybe_top_p_arr,
  119. double top_p_val);
  120. torch::Tensor top_k_renorm_prob(torch::Tensor probs,
  121. std::optional<torch::Tensor> maybe_top_k_arr,
  122. int64_t top_k_val);
  123. torch::Tensor top_k_mask_logits(torch::Tensor logits,
  124. std::optional<torch::Tensor> maybe_top_k_arr,
  125. int64_t top_k_val);
  126. #endif