// torch_bindings.cpp

#include "../core/registration.h"
#include "moe_ops.h"
  3. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  4. // Apply topk softmax to the gating outputs.
  5. m.def(
  6. "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
  7. "token_expert_indices, Tensor gating_output) -> ()");
  8. m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
  9. }
  10. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)