#include "core/registration.h" #include "rocm/ops.h" // Note on op signatures: // The X_meta signatures are for the meta functions corresponding to op X. // They must be kept in sync with the signature for X. Generally, only // functions that return Tensors require a meta function. // // See the following links for detailed docs on op registration and function // schemas. // https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // Aphrodite custom ops for rocm // Custom attention op // Compute the attention between an input query and the cached // keys/values using PagedAttention. rocm_ops.def( "paged_attention(Tensor! out, Tensor exp_sums," " Tensor max_logits, Tensor tmp_out," " Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads," " float scale, Tensor block_tables," " Tensor context_lens, int block_size," " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," " float k_scale, float v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)