moe_ops.h 217 B

1234567
  1. #pragma once
  2. #include <torch/all.h>
  3. void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
  4. torch::Tensor& token_expert_indices,
  5. torch::Tensor& gating_output);