1
0

torch_bindings.cpp 437 B

12345678910111213141516
  1. #ifdef _WIN32
  2. #include <crtdefs.h>
  3. #endif
  4. #include "../core/registration.h"
  5. #include "moe_ops.h"
  6. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  7. // Apply topk softmax to the gating outputs.
  8. m.def(
  9. "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
  10. "token_expert_indices, Tensor gating_output) -> ()");
  11. m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
  12. }
  13. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)