12345678910111213141516 |
- #include <torch/extension.h>
- void rotary_embedding(
- torch::Tensor& positions,
- torch::Tensor& query,
- torch::Tensor& key,
- int head_size,
- torch::Tensor& cos_sin_cache,
- bool is_neox);
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def(
- "rotary_embedding",
- &rotary_embedding,
- "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
- }
|