torch_bindings.cpp

  1. #include "core/registration.h"
  2. #include "rocm/ops.h"
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
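//
// Illustration only (not an op registered in this file): for a hypothetical
// Tensor-returning op `my_op` with a device implementation `my_op` and a
// matching meta function `my_op_meta`, the registrations would follow the
// same pattern used below, assuming the torch::kMeta dispatch-key alias
// (c10::DispatchKey::Meta) for the meta backend:
//
//   ops.def("my_op(Tensor a, float scale) -> Tensor");
//   ops.impl("my_op", torch::kCUDA, &my_op);
//   ops.impl("my_op", torch::kMeta, &my_op_meta);
//
// paged_attention below returns () and writes into `out` in place, so per
// the note above it needs no meta function.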

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
  // Aphrodite custom ops for rocm

  // Custom attention op
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  rocm_ops.def(
      "paged_attention(Tensor! out, Tensor exp_sums,"
      " Tensor max_logits, Tensor tmp_out,"
      " Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads,"
      " float scale, Tensor block_tables,"
      " Tensor context_lens, int block_size,"
      " int max_context_len,"
      " Tensor? alibi_slopes,"
      " str kv_cache_dtype,"
      " float k_scale, float v_scale) -> ()");
  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)