// xqa_ops.h — declaration of the XQA paged-attention custom op.
  1. #pragma once
  2. #include <torch/all.h>
  3. void xqa_paged_attention(torch::Tensor& out, torch::Tensor& query,
  4. torch::Tensor& key_value_cache, int64_t num_heads,
  5. int64_t num_kv_heads, int64_t rotary_embedding_dim,
  6. double scale, torch::Tensor& block_tables,
  7. torch::Tensor& seq_lens, int64_t block_size,
  8. int64_t max_seq_len, const std::string kv_cache_dtype,
  9. double k_scale, double v_scale);