moe_ops.h 204 B

123456789
  1. #pragma once
  2. #include <torch/extension.h>
  // Declaration of topk_softmax; the definition is not part of this header.
  //
  // The function returns void and takes four torch::Tensor arguments, all by
  // non-const lvalue reference:
  //   topk_weights, topk_indices, token_expert_indices, gating_output
  // Because every parameter is a non-const reference, the signature alone does
  // not distinguish inputs from outputs, and callers must pass named tensors
  // (temporaries will not bind).
  //
  // NOTE(review): going by the parameter names, this presumably applies a
  // softmax to `gating_output` and writes the top-k results into the other
  // three tensors -- confirm against the implementation.
  // NOTE(review): expected shapes, dtypes and device placement of the tensors
  // are not visible from this declaration -- TODO confirm and document here.
  3. void topk_softmax(
  4. torch::Tensor& topk_weights,
  5. torch::Tensor& topk_indices,
  6. torch::Tensor& token_expert_indices,
  7. torch::Tensor& gating_output);