// ops.h (6.4 KB) — operator entry-point declarations.
// (Removed line-number gutter and page residue from the scraped copy.)
  1. #pragma once
  2. #include <torch/extension.h>
  3. void paged_attention_v1(
  4. torch::Tensor& out,
  5. torch::Tensor& query,
  6. torch::Tensor& key_cache,
  7. torch::Tensor& value_cache,
  8. int num_kv_heads,
  9. float scale,
  10. torch::Tensor& block_tables,
  11. torch::Tensor& context_lens,
  12. int block_size,
  13. int max_context_len,
  14. const c10::optional<torch::Tensor>& alibi_slopes,
  15. const std::string& kv_cache_dtype,
  16. float kv_scale);
  17. void paged_attention_v2(
  18. torch::Tensor& out,
  19. torch::Tensor& exp_sums,
  20. torch::Tensor& max_logits,
  21. torch::Tensor& tmp_out,
  22. torch::Tensor& query,
  23. torch::Tensor& key_cache,
  24. torch::Tensor& value_cache,
  25. int num_kv_heads,
  26. float scale,
  27. torch::Tensor& block_tables,
  28. torch::Tensor& context_lens,
  29. int block_size,
  30. int max_context_len,
  31. const c10::optional<torch::Tensor>& alibi_slopes,
  32. const std::string& kv_cache_dtype,
  33. float kv_scale);
  34. void rms_norm(
  35. torch::Tensor& out,
  36. torch::Tensor& input,
  37. torch::Tensor& weight,
  38. float epsilon);
  39. void fused_add_rms_norm(
  40. torch::Tensor& input,
  41. torch::Tensor& residual,
  42. torch::Tensor& weight,
  43. float epsilon);
  44. void rotary_embedding(
  45. torch::Tensor& positions,
  46. torch::Tensor& query,
  47. torch::Tensor& key,
  48. int head_size,
  49. torch::Tensor& cos_sin_cache,
  50. bool is_neox);
  51. void batched_rotary_embedding(
  52. torch::Tensor& positions,
  53. torch::Tensor& query,
  54. torch::Tensor& key,
  55. int head_size,
  56. torch::Tensor& cos_sin_cache,
  57. bool is_neox,
  58. int rot_dim,
  59. torch::Tensor& cos_sin_cache_offsets);
  60. void silu_and_mul(
  61. torch::Tensor& out,
  62. torch::Tensor& input);
  63. void gelu_and_mul(
  64. torch::Tensor& out,
  65. torch::Tensor& input);
  66. void gelu_tanh_and_mul(
  67. torch::Tensor& out,
  68. torch::Tensor& input);
  69. void gelu_new(
  70. torch::Tensor& out,
  71. torch::Tensor& input);
  72. void gelu_fast(
  73. torch::Tensor& out,
  74. torch::Tensor& input);
  75. #ifndef USE_ROCM
  76. torch::Tensor awq_gemm(
  77. torch::Tensor _in_feats,
  78. torch::Tensor _kernel,
  79. torch::Tensor _scaling_factors,
  80. torch::Tensor _zeros,
  81. int split_k_iters);
  82. torch::Tensor autoquant_s4_f16_gemm(
  83. torch::Tensor _in_feats,
  84. torch::Tensor _kernel,
  85. torch::Tensor _scales_zeros);
  86. void autoquant_convert_s4_k_m8(
  87. torch::Tensor _weight_dest,
  88. torch::Tensor _quant_scales_zeros_dest,
  89. torch::Tensor _workspace,
  90. torch::Tensor _quant_weight_src,
  91. torch::Tensor _quant_scales,
  92. torch::Tensor _quant_zeros,
  93. int m,
  94. int k,
  95. int group_size);
  96. torch::Tensor aqlm_gemm(
  97. const torch::Tensor& input,
  98. const torch::Tensor& codes,
  99. const torch::Tensor& codebooks,
  100. const torch::Tensor& scales,
  101. const torch::Tensor& codebook_partition_sizes,
  102. const std::optional<torch::Tensor>& bias
  103. );
  104. at::Tensor e8p_mm_origorder(
  105. const at::Tensor& A,
  106. const at::Tensor& B,
  107. const at::Tensor& CB);
  108. void decompress_e8p_origorder(
  109. torch::Tensor YIs,
  110. torch::Tensor CB,
  111. torch::Tensor &Y
  112. );
  113. torch::Tensor awq_dequantize(
  114. torch::Tensor _kernel,
  115. torch::Tensor _scaling_factors,
  116. torch::Tensor _zeros,
  117. int split_k_iters,
  118. int thx,
  119. int thy);
  120. torch::Tensor awq_group_gemm(
  121. torch::Tensor _in_feats,
  122. torch::Tensor _kernel,
  123. torch::Tensor _scaling_factors,
  124. torch::Tensor _zeros,
  125. torch::Tensor _topk_weights,
  126. torch::Tensor _sorted_token_ids_ptr,
  127. torch::Tensor _expert_ids_ptr,
  128. torch::Tensor _num_tokens_post_padded,
  129. bool mul_weights,
  130. int split_k_iters);
  131. torch::Tensor marlin_gemm(
  132. torch::Tensor& a,
  133. torch::Tensor& b_q_weight,
  134. torch::Tensor& b_scales,
  135. torch::Tensor& workspace,
  136. int64_t size_m,
  137. int64_t size_n,
  138. int64_t size_k);
  139. #endif
  140. void squeezellm_gemm(
  141. torch::Tensor vec,
  142. torch::Tensor mat,
  143. torch::Tensor mul,
  144. torch::Tensor lookup_table);
  145. torch::Tensor gptq_gemm(
  146. torch::Tensor a,
  147. torch::Tensor b_q_weight,
  148. torch::Tensor b_gptq_qzeros,
  149. torch::Tensor b_gptq_scales,
  150. torch::Tensor b_g_idx,
  151. bool use_exllama,
  152. int bit);
  153. void gptq_shuffle(
  154. torch::Tensor q_weight,
  155. torch::Tensor q_perm,
  156. int bit);
  157. torch::Tensor ggml_dequantize(
  158. torch::Tensor X,
  159. int8_t type,
  160. int64_t m,
  161. int64_t n
  162. );
  163. torch::Tensor ggml_mul_mat_vec(
  164. torch::Tensor W, // quant weight
  165. torch::Tensor X, // input
  166. int8_t type,
  167. int64_t m
  168. );
  169. torch::Tensor ggml_mul_mat_vec_a8(
  170. torch::Tensor W, // quant weight
  171. torch::Tensor X, // input
  172. int8_t type,
  173. int64_t row
  174. );
  175. torch::Tensor ggml_mul_mat_a8(
  176. torch::Tensor W, // quant weight
  177. torch::Tensor X, // input
  178. int8_t type,
  179. int64_t row
  180. );
  181. uintptr_t make_q_matrix(
  182. torch::Tensor q_weight,
  183. torch::Tensor q_perm,
  184. torch::Tensor q_invperm,
  185. torch::Tensor q_scale,
  186. torch::Tensor q_scale_max,
  187. torch::Tensor q_groups,
  188. torch::Tensor q_group_map
  189. );
  190. torch::Tensor exl2_gemm(
  191. torch::Tensor a,
  192. uintptr_t b
  193. );
  194. torch::Tensor group_gptq_gemm(
  195. torch::Tensor a,
  196. torch::Tensor b_q_weight,
  197. torch::Tensor b_gptq_qzeros,
  198. torch::Tensor b_gptq_scales,
  199. torch::Tensor b_g_idx,
  200. torch::Tensor topk_weights,
  201. torch::Tensor sorted_token_ids_ptr,
  202. torch::Tensor expert_ids_ptr,
  203. torch::Tensor num_tokens_post_padded,
  204. bool mul_weights,
  205. bool use_exllama
  206. );
  207. torch::Tensor dequant_gptq(
  208. torch::Tensor b_q_weight,
  209. torch::Tensor b_gptq_qzeros,
  210. torch::Tensor b_gptq_scales,
  211. torch::Tensor b_g_idx,
  212. int bits,
  213. bool use_exllama
  214. );
  215. void moe_align_block_size(
  216. torch::Tensor topk_ids,
  217. int num_experts,
  218. int block_size,
  219. torch::Tensor sorted_token_ids,
  220. torch::Tensor expert_ids,
  221. torch::Tensor num_tokens_post_pad
  222. );
  223. #ifndef USE_ROCM
  224. using fptr_t = uint64_t;
  225. fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
  226. const std::vector<std::string> &handles,
  227. const std::vector<int64_t> &offsets, int rank,
  228. bool full_nvlink);
  229. bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
  230. bool full_nvlink);
  231. void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
  232. void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
  233. torch::Tensor &out);
  234. void dispose(fptr_t _fa);
  235. int meta_size();
  236. void register_buffer(fptr_t _fa, torch::Tensor &t,
  237. const std::vector<std::string> &handles,
  238. const std::vector<int64_t> &offsets);
  239. std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
  240. void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
  241. const std::vector<std::vector<int64_t>> &offsets);
  242. #endif