// torch_bindings.cpp (5.6 KB): Aphrodite CPU custom-op registrations.
  1. #include "cache.h"
  2. #include "ops.h"
  3. #include "core/registration.h"
  4. #include <torch/library.h>
  5. std::string init_cpu_threads_env(const std::string& cpu_ids);
  6. void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
  7. const torch::Tensor& b, const torch::Tensor& a_scales,
  8. const torch::Tensor& b_scales,
  9. const c10::optional<torch::Tensor>& bias);
  10. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  11. // Aphrodite custom ops
  12. // Attention ops
  13. // Compute the attention between an input query and the cached keys/values
  14. // using PagedAttention.
  15. ops.def(
  16. "paged_attention_v1("
  17. " Tensor! out, Tensor query, Tensor key_cache,"
  18. " Tensor value_cache, int num_kv_heads, float scale,"
  19. " Tensor block_tables, Tensor seq_lens, int block_size,"
  20. " int max_seq_len, Tensor? alibi_slopes,"
  21. " str kv_cache_dtype, float k_scale, float v_scale,"
  22. " int tp_rank, int blocksparse_local_blocks,"
  23. " int blocksparse_vert_stride, int blocksparse_block_size,"
  24. " int blocksparse_head_sliding_step) -> ()");
  25. ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
  26. // PagedAttention V2.
  27. ops.def(
  28. "paged_attention_v2("
  29. " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
  30. " Tensor! tmp_out, Tensor query, Tensor key_cache,"
  31. " Tensor value_cache, int num_kv_heads, float scale,"
  32. " Tensor block_tables, Tensor seq_lens, int block_size,"
  33. " int max_seq_len, Tensor? alibi_slopes,"
  34. " str kv_cache_dtype, float k_scale, float v_scale,"
  35. " int tp_rank, int blocksparse_local_blocks,"
  36. " int blocksparse_vert_stride, int blocksparse_block_size,"
  37. " int blocksparse_head_sliding_step) -> ()");
  38. ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
  39. // Activation ops
  40. // Activation function used in SwiGLU.
  41. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  42. ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);
  43. // Activation function used in GeGLU with `none` approximation.
  44. ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  45. ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);
  46. // Activation function used in GeGLU with `tanh` approximation.
  47. ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  48. ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);
  49. // GELU implementation used in GPT-2.
  50. ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  51. ops.impl("gelu_new", torch::kCPU, &gelu_new);
  52. // Approximate GELU implementation.
  53. ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  54. ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
  55. // Quick GELU implementation.
  56. ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  57. ops.impl("gelu_quick", torch::kCPU, &gelu_quick);
  58. // Layernorm
  59. // Apply Root Mean Square (RMS) Normalization to the input tensor.
  60. ops.def(
  61. "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
  62. "()");
  63. ops.impl("rms_norm", torch::kCPU, &rms_norm);
  64. // In-place fused Add and RMS Normalization.
  65. ops.def(
  66. "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
  67. "float epsilon) -> ()");
  68. ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);
  69. // Rotary embedding
  70. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  71. ops.def(
  72. "rotary_embedding(Tensor positions, Tensor! query,"
  73. " Tensor! key, int head_size,"
  74. " Tensor cos_sin_cache, bool is_neox) -> ()");
  75. ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
  76. // Quantization
  77. #ifdef __AVX512F__
  78. // Compute int8 quantized tensor for given scaling factor.
  79. ops.def(
  80. "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
  81. "Tensor? azp) -> ()");
  82. ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
  83. // Compute int8 quantized tensor and scaling factor
  84. ops.def(
  85. "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
  86. "Tensor!? azp) -> ()");
  87. ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
  88. &dynamic_scaled_int8_quant);
  89. // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
  90. // quantization.
  91. ops.def(
  92. "cutlass_scaled_mm(Tensor! out, Tensor a,"
  93. " Tensor b, Tensor a_scales,"
  94. " Tensor b_scales, Tensor? bias) -> ()");
  95. ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
  96. #endif
  97. }
  98. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  99. // Cache ops
  100. // Swap in (out) the cache blocks from src to dst.
  101. cache_ops.def(
  102. "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  103. cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
  104. // Copy the cache blocks from src to dst.
  105. cache_ops.def(
  106. "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
  107. "Tensor block_mapping) -> ()");
  108. cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
  109. // Reshape the key and value tensors and cache them.
  110. cache_ops.def(
  111. "reshape_and_cache(Tensor key, Tensor value,"
  112. " Tensor! key_cache, Tensor! value_cache,"
  113. " Tensor slot_mapping,"
  114. " str kv_cache_dtype,"
  115. " float k_scale, float v_scale) -> ()");
  116. cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
  117. }
  118. TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
  119. // CPU utils
  120. utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
  121. }
  122. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)