#include void rotary_embedding( torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int head_size, torch::Tensor& cos_sin_cache, bool is_neox); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "rotary_embedding", &rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); }