pos_encoding.cpp

#include <torch/extension.h>

// Applies rotary position embedding to the query and key tensors in place,
// using a precomputed cos/sin cache. The definition lives in the
// accompanying kernel source; only the declaration appears here.
void rotary_embedding(
    torch::Tensor& positions,
    torch::Tensor& query,
    torch::Tensor& key,
    int head_size,
    torch::Tensor& cos_sin_cache,
    bool is_neox);

// Expose the function to Python as a torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "rotary_embedding",
      &rotary_embedding,
      "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
}
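
A minimal usage sketch of the binding above, assuming the extension is built with torch.utils.cpp_extension.load on a CUDA-capable machine. The kernel source file name, tensor shapes, dtype, and the cos/sin cache layout are illustrative assumptions; this header does not specify them, and keyword arguments are not available because the binding does not name its parameters.

# Hedged sketch, not the project's actual build or test script.
import torch
from torch.utils.cpp_extension import load

# Build the extension; "pos_encoding_kernels.cu" is a hypothetical name for
# the file that holds the kernel definition.
pos_encoding = load(name="pos_encoding",
                    sources=["pos_encoding.cpp", "pos_encoding_kernels.cu"])

num_tokens, num_heads, head_size, max_position = 4, 8, 64, 2048

positions = torch.arange(num_tokens, dtype=torch.long, device="cuda")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn(num_tokens, num_heads * head_size,
                  dtype=torch.float16, device="cuda")

# Precompute the cos/sin cache; the [max_position, head_size] layout with
# cosines in the first half and sines in the second half is an assumption.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2,
                                           dtype=torch.float32) / head_size))
t = torch.arange(max_position, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(
    device="cuda", dtype=torch.float16)

# Apply GPT-NeoX style rotary embedding in place to query and key.
# Arguments are passed positionally, matching the declaration order.
pos_encoding.rotary_embedding(positions, query, key, head_size,
                              cos_sin_cache, True)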