// torch_bindings.cpp

  1. #include "cache.h"
  2. #include "cuda_utils.h"
  3. #include "ops.h"
  4. #include "registration.h"
  5. #include "quantization/quant_ops.h"
  6. #include <torch/library.h>
// Note on op signatures:
// An op X may have a corresponding meta function, X_meta. Its signature must
// be kept in sync with the signature of X. Generally, only ops that return
// Tensors need a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas:
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
  16. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  17. // Aphrodite custom ops
  18. // Attention ops
  19. // Compute the attention between an input query and the cached
  20. // keys/values using PagedAttention.
  21. ops.def(
  22. "paged_attention_v1("
  23. " Tensor! out, Tensor query, Tensor key_cache,"
  24. " Tensor value_cache, int num_kv_heads, float scale,"
  25. " Tensor block_tables, Tensor seq_lens, int block_size,"
  26. " int max_seq_len, Tensor? alibi_slopes,"
  27. " str kv_cache_dtype, float kv_scale, int tp_rank,"
  28. " int blocksparse_local_blocks,"
  29. " int blocksparse_vert_stride, int blocksparse_block_size,"
  30. " int blocksparse_head_sliding_step) -> ()");
  31. ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
  32. // PagedAttention V2.
  33. ops.def(
  34. "paged_attention_v2("
  35. " Tensor! out, Tensor exp_sums, Tensor max_logits,"
  36. " Tensor tmp_out, Tensor query, Tensor key_cache,"
  37. " Tensor value_cache, int num_kv_heads, float scale,"
  38. " Tensor block_tables, Tensor seq_lens, int block_size,"
  39. " int max_seq_len, Tensor? alibi_slopes,"
  40. " str kv_cache_dtype, float kv_scale, int tp_rank,"
  41. " int blocksparse_local_blocks,"
  42. " int blocksparse_vert_stride, int blocksparse_block_size,"
  43. " int blocksparse_head_sliding_step) -> ()");
  44. ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
  45. // Activation ops
  46. // Activation function used in SwiGLU.
  47. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  48. ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
  49. // Activation function used in GeGLU with `none` approximation.
  50. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  51. ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
  52. // Activation function used in GeGLU with `tanh` approximation.
  53. ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  54. ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
  55. // GELU implementation used in GPT-2.
  56. ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  57. ops.impl("gelu_new", torch::kCUDA, &gelu_new);
  58. // Approximate GELU implementation.
  59. ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  60. ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
  61. // Layernorm
  62. // Apply Root Mean Square (RMS) Normalization to the input tensor.
  63. ops.def(
  64. "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
  65. "()");
  66. ops.impl("rms_norm", torch::kCUDA, &rms_norm);
  67. // In-place fused Add and RMS Normalization.
  68. ops.def(
  69. "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
  70. "float epsilon) -> ()");
  71. ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
  72. // Rotary embedding
  73. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  74. ops.def(
  75. "rotary_embedding(Tensor positions, Tensor! query,"
  76. " Tensor! key, int head_size,"
  77. " Tensor cos_sin_cache, bool is_neox) -> ()");
  78. ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
  79. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
  80. // (supports multiple loras).
  81. ops.def(
  82. "batched_rotary_embedding(Tensor positions, Tensor! query,"
  83. " Tensor! key, int head_size,"
  84. " Tensor cos_sin_cache, bool is_neox,"
  85. " int rot_dim,"
  86. " Tensor cos_sin_cache_offsets) -> ()");
  87. ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
  88. // Quantization ops
  89. #ifndef USE_ROCM
  90. // Quantized GEMM for AQLM.
  91. ops.def("aqlm_gemm", &aqlm_gemm);
  92. ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
  93. // Decompression method for AQLM.
  94. ops.def("aqlm_dequant", &aqlm_dequant);
  95. ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
  96. // Quantized GEMM for AWQ.
  97. ops.def("awq_gemm", &awq_gemm);
  98. ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
  99. // Dequantization for AWQ.
  100. ops.def("awq_dequantize", &awq_dequantize);
  101. ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
  102. // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  103. ops.def("marlin_gemm", &marlin_gemm);
  104. ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
  105. // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  106. ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
  107. ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
  108. // gptq_marlin Optimized Quantized GEMM for GPTQ.
  109. ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
  110. ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
  111. // gptq_marlin repack from GPTQ.
  112. ops.def("gptq_marlin_repack", &gptq_marlin_repack);
  113. ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
  114. // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  115. // quantization.
  116. ops.def(
  117. "cutlass_scaled_mm(Tensor! out, Tensor a,"
  118. " Tensor b, Tensor a_scales,"
  119. " Tensor b_scales) -> ()");
  120. ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
  121. // Check if cutlass scaled_mm is supported for CUDA devices of the given
  122. // capability
  123. ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
  124. ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
  125. &cutlass_scaled_mm_supports_fp8);
  126. // QuIP# GEMV
  127. ops.def("quip_gemv", &e8p_mm_origorder);
  128. ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
  129. // QuIP# Decompress
  130. ops.def("quip_decompress", &decompress_e8p_origorder);
  131. ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
  132. #endif
  133. // Quantized GEMM for GPTQ.
  134. ops.def("gptq_gemm", &gptq_gemm);
  135. ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
  136. // Post processing for GPTQ.
  137. ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  138. ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
  139. // Quantized GEMM for SqueezeLLM.
  140. ops.def(
  141. "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
  142. "lookup_table) -> ()");
  143. ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
  144. // Compute FP8 quantized tensor for given scaling factor.
  145. ops.def(
  146. "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
  147. ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
  148. // Compute FP8 quantized tensor and scaling factor.
  149. ops.def(
  150. "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  151. "()");
  152. ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
  153. // Aligning the number of tokens to be processed by each expert such
  154. // that it is divisible by the block size.
  155. ops.def(
  156. "moe_align_block_size(Tensor topk_ids, int num_experts,"
  157. " int block_size, Tensor! sorted_token_ids,"
  158. " Tensor! experts_ids,"
  159. " Tensor! num_tokens_post_pad) -> ()");
  160. ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  161. // Compute int8 quantized tensor for given scaling factor.
  162. ops.def(
  163. "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
  164. "()");
  165. ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
  166. // Compute int8 quantized tensor and scaling factor
  167. ops.def(
  168. "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
  169. "()");
  170. ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
  171. &dynamic_scaled_int8_quant);
  172. }
  173. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  174. // Cache ops
  175. // Swap in (out) the cache blocks from src to dst.
  176. cache_ops.def(
  177. "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  178. cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
  179. // Copy the cache blocks from src to dst.
  180. cache_ops.def(
  181. "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
  182. "block_mapping) -> ()");
  183. cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
  184. // Reshape the key and value tensors and cache them.
  185. cache_ops.def(
  186. "reshape_and_cache(Tensor key, Tensor value,"
  187. " Tensor! key_cache, Tensor! value_cache,"
  188. " Tensor slot_mapping,"
  189. " str kv_cache_dtype,"
  190. " float kv_scale) -> ()");
  191. cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
  192. // Reshape the key and value tensors and cache them.
  193. cache_ops.def(
  194. "reshape_and_cache_flash(Tensor key, Tensor value,"
  195. " Tensor! key_cache,"
  196. " Tensor! value_cache,"
  197. " Tensor slot_mapping,"
  198. " str kv_cache_dtype) -> ()");
  199. cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
  200. &reshape_and_cache_flash);
  201. // Convert the key and value cache to fp8 data type.
  202. cache_ops.def(
  203. "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
  204. "kv_cache_dtype) -> ()");
  205. cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
  206. }
  207. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
  208. // Cuda utils
  209. // Gets the specified device attribute.
  210. cuda_utils.def("get_device_attribute", &get_device_attribute);
  211. cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
  212. // Gets the maximum shared memory per block device attribute.
  213. cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
  214. &get_max_shared_memory_per_block_device_attribute);
  215. cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
  216. torch::kCUDA,
  217. &get_max_shared_memory_per_block_device_attribute);
  218. }
  219. #ifndef USE_ROCM
  220. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  221. // Custom all-reduce kernels
  222. custom_ar.def("init_custom_ar", &init_custom_ar);
  223. custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
  224. custom_ar.def("should_custom_ar", &should_custom_ar);
  225. custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
  226. custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
  227. custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
  228. custom_ar.def(
  229. "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
  230. "()");
  231. custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
  232. custom_ar.def("dispose", &dispose);
  233. custom_ar.impl("dispose", torch::kCPU, &dispose);
  234. custom_ar.def("meta_size", &meta_size);
  235. custom_ar.impl("meta_size", torch::kCPU, &meta_size);
  236. custom_ar.def("register_buffer", &register_buffer);
  237. custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
  238. custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  239. custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
  240. &get_graph_buffer_ipc_meta);
  241. custom_ar.def("register_graph_buffers", &register_graph_buffers);
  242. custom_ar.impl("register_graph_buffers", torch::kCPU,
  243. &register_graph_buffers);
  244. }
  245. #endif
  246. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)