123456789101112 |
- #include "../core/registration.h"
- #include "moe_ops.h"
- TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {  // Operator registrations for this extension; `m` is the handle the calls below register on.
- // Apply topk softmax to the gating outputs: declare the op's schema, then bind its CUDA implementation.
- m.def(  // The schema string below is split across two adjacent literals that the compiler concatenates.
- "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "  // `Tensor!` is the PyTorch schema shorthand for an argument the op may write to.
- "token_expert_indices, Tensor gating_output) -> ()");  // `-> ()`: the op returns nothing; gating_output is the only argument not marked `!`.
- m.impl("topk_softmax", torch::kCUDA, &topk_softmax);  // Bind the function `topk_softmax` (presumably declared in moe_ops.h — confirm) for the CUDA dispatch key.
- }
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)  // NOTE(review): macro is not defined in this file — presumably provided by ../core/registration.h; confirm what it emits.
|