
feat: fused top-k kernels for MoE (#273)

* feat: add top-k gating softmax kernel

* add ops

* make new kernels compilable

* switch to built-in shfl xor sync

* AlignedArray -> AccessType

* define idx

* fix compile errors

* integrate in triton kernel

* support mixtral and deepseek

* pylint
AlpinDale 1 year ago
parent
commit
7d6ba53602

+ 48 - 9
aphrodite/modeling/layers/triton_kernel/fused_moe.py

@@ -4,6 +4,7 @@ import triton
 import triton.language as tl
 
 from aphrodite._C import ops
+from aphrodite.common.utils import is_hip
 
 
 @triton.jit
@@ -231,12 +232,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
 
 
-def fused_moe(hidden_states: torch.Tensor,
-              w1: torch.Tensor,
-              w2: torch.Tensor,
-              topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor,
-              inplace=False):
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
     weights, w1 and w2, and top-k gating mechanism.
@@ -245,8 +249,9 @@ def fused_moe(hidden_states: torch.Tensor,
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - topk_weights (torch.Tensor): The weights for the top-k selected experts.
-    - topk_ids (torch.Tensor): The indices of the top-k selected experts.
+    - gating_output (torch.Tensor): The output of the gating operation (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to
         False.
 
@@ -254,7 +259,10 @@ def fused_moe(hidden_states: torch.Tensor,
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], 'Incompatible dimensions'
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        'Number of tokens mismatch')
+    assert hidden_states.shape[1] == w1.shape[2], 'Hidden size mismatch'
+    assert gating_output.shape[1] == w1.shape[0], 'Number of experts mismatch'
     assert hidden_states.is_contiguous(), 'Hidden_states must be contiguous'
     assert w1.is_contiguous(), 'Expert weights1 must be contiguous'
     assert w2.is_contiguous(), 'Expert weights2 must be contiguous'
@@ -262,6 +270,37 @@ def fused_moe(hidden_states: torch.Tensor,
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import aphrodite._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=hidden_states.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=hidden_states.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=hidden_states.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO: Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,

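With the signature change above, callers pass the raw router logits together with topk and renormalize instead of precomputed topk_weights/topk_ids. A minimal usage sketch on a CUDA device (the tiny shapes and the 8-expert/top-2 setup are illustrative assumptions, not values from this commit):

import torch
from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe

num_tokens, hidden, inter, num_experts, topk = 4, 128, 256, 8, 2
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(num_experts, 2 * inter, hidden, dtype=torch.float16, device="cuda")  # fused gate/up projection
w2 = torch.randn(num_experts, hidden, inter, dtype=torch.float16, device="cuda")      # down projection
router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float16, device="cuda")

out = fused_moe(hidden_states, w1, w2, router_logits, topk, renormalize=True)
assert out.shape == hidden_states.shape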
+ 3 - 11
aphrodite/modeling/models/deepseek.py

@@ -26,7 +26,6 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
-import torch.nn.functional as F
 from transformers import PretrainedConfig
 
 from aphrodite.modeling.metadata import InputMetadata
@@ -173,19 +172,12 @@ class DeepseekMoE(nn.Module):
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-
-        if self.config.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
                                         self.w2,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=self.config.norm_topk_prob,
                                         inplace=True)
 
         if self.config.n_shared_experts is not None:

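The routing computation deleted from DeepseekMoE.forward above now lives inside fused_moe (on ROCm it falls back to exactly this eager path). For reference, a sketch of the equivalent eager routing, with renormalize playing the role of config.norm_topk_prob:

import torch

def reference_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    # Softmax over experts, then keep the top-k probabilities per token.
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(routing_weights, top_k, dim=-1)
    if renormalize:  # corresponds to config.norm_topk_prob for DeepSeek
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids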
+ 3 - 9
aphrodite/modeling/models/mixtral.py

@@ -25,7 +25,6 @@
 from typing import List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 
 from torch import nn
 from transformers import MixtralConfig
@@ -129,17 +128,12 @@ class MixtralMoE(nn.Module):
         # router_logits: (batch * sequence_length, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
                                         self.w2s,
-                                        routing_weights,
-                                        selected_experts,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=True,
                                         inplace=True)
 
         final_hidden_states = tensor_model_parallel_all_reduce(

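For Mixtral, renormalize is hard-coded to True, mirroring the reference implementation's unconditional division of the top-k weights by their sum. A small check (illustrative shapes) showing that this renormalization is the same as taking a softmax over only the top-k logits:

import torch

logits = torch.randn(4, 8)                       # 4 tokens, 8 experts (illustrative)
probs = torch.softmax(logits, dim=-1)
topk_w, topk_ids = torch.topk(probs, k=2, dim=-1)
renorm = topk_w / topk_w.sum(dim=-1, keepdim=True)

# Renormalizing the top-k probabilities equals a softmax over just the top-k logits.
topk_logits = torch.gather(logits, -1, topk_ids)
torch.testing.assert_close(renorm, torch.softmax(topk_logits, dim=-1))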
+ 7 - 0
kernels/moe/moe_ops.cpp

@@ -0,0 +1,7 @@
+#include "moe_ops.h"
+
+#include <torch/extension.h>
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("topk_softmax", &topk_softmax, "Apply top-k softmax to the gating outputs.");
+}

+ 9 - 0
kernels/moe/moe_ops.h

@@ -0,0 +1,9 @@
+#pragma once
+
+#include <torch/extension.h>
+
+void topk_softmax(
+    torch::Tensor& topk_weights,
+    torch::Tensor& topk_indices,
+    torch::Tensor& token_expert_indices,
+    torch::Tensor& gating_output);

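The binding writes its results into pre-allocated output tensors, as the Python caller in fused_moe.py above does: float32 weights and int32 index tensors of shape (num_tokens, topk), on the same CUDA device as the float32 gating output. A hedged sketch of calling it directly, assuming the aphrodite._moe_C extension has been built:

import torch
import aphrodite._moe_C as moe_kernels  # assumes the extension added in setup.py below is built

num_tokens, num_experts, topk = 16, 8, 2  # illustrative sizes
gating_output = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")

topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

# Fills the three output tensors in place.
moe_kernels.topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)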
+ 491 - 0
kernels/moe/softmax.cu

@@ -0,0 +1,491 @@
+/*
+ * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
+ * Copyright (c) 2024, The PygmalionAI team.
+ * Copyright (c) 2024, The vLLM team.
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cub/cub.cuh>
+#include <cub/util_type.cuh>
+
+#include "../cuda_compat.h"
+
+namespace aphrodite {
+namespace moe {
+
+static constexpr int WARP_SIZE = 32;
+
+// Aligned array type
+template <typename T, int N, int Alignment = sizeof(T) * N>
+class alignas(Alignment) AlignedArray {
+  float data[N];
+};
+
+// We have our own implementation of softmax here so we can support transposing the output
+// in the softmax kernel when we extend this module to support expert-choice routing.
+
+template <int TPB>
+__launch_bounds__(TPB) __global__
+    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
+{
+    using BlockReduce = cub::BlockReduce<float, TPB>;
+    __shared__ typename BlockReduce::TempStorage tmpStorage;
+
+    __shared__ float normalizing_factor;
+    __shared__ float float_max;
+
+    const int thread_row_offset = blockIdx.x * num_cols;
+
+    cub::Sum sum;
+    float threadData(-FLT_MAX);
+
+    // Don't touch finished rows.
+    if ((finished != nullptr) && finished[blockIdx.x])
+    {
+        return;
+    }
+
+    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
+    {
+        const int idx = thread_row_offset + ii;
+        threadData = max(static_cast<float>(input[idx]), threadData);
+    }
+
+    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+    if (threadIdx.x == 0)
+    {
+        float_max = maxElem;
+    }
+    __syncthreads();
+
+    threadData = 0;
+
+    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
+    {
+        const int idx = thread_row_offset + ii;
+        threadData += exp((static_cast<float>(input[idx]) - float_max));
+    }
+
+    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+
+    if (threadIdx.x == 0)
+    {
+        normalizing_factor = 1.f / Z;
+    }
+    __syncthreads();
+
+    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
+    {
+        const int idx = thread_row_offset + ii;
+        const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
+        output[idx] = val;
+    }
+}
+
+template <int TPB>
+__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
+    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+{
+
+    using cub_kvp = cub::KeyValuePair<int, float>;
+    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
+    __shared__ typename BlockReduce::TempStorage tmpStorage;
+
+    cub_kvp thread_kvp;
+    cub::ArgMax arg_max;
+
+    const int num_rows = gridDim.x;
+    const int block_row = blockIdx.x;
+
+    const bool row_is_active = finished ? !finished[block_row] : true;
+    const int thread_read_offset = blockIdx.x * num_experts;
+    for (int k_idx = 0; k_idx < k; ++k_idx)
+    {
+        thread_kvp.key = 0;
+        thread_kvp.value = -1.f; // This is OK because inputs are probabilities
+
+        cub_kvp inp_kvp;
+        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
+        {
+            const int idx = thread_read_offset + expert;
+            inp_kvp.key = expert;
+            inp_kvp.value = inputs_after_softmax[idx];
+
+            for (int prior_k = 0; prior_k < k_idx; ++prior_k)
+            {
+                const int prior_winning_expert = indices[k * block_row + prior_k];
+
+                if (prior_winning_expert == expert)
+                {
+                    inp_kvp = thread_kvp;
+                }
+            }
+
+            thread_kvp = arg_max(inp_kvp, thread_kvp);
+        }
+
+        const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
+        if (threadIdx.x == 0)
+        {
+            // Ignore experts the node isn't responsible for with expert parallelism
+            const int expert = result_kvp.key;
+            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
+            const bool should_process_row = row_is_active && node_uses_expert;
+
+            const int idx = k * block_row + k_idx;
+            output[idx] = result_kvp.value;
+            indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
+            assert(indices[idx] >= 0);
+            source_rows[idx] = k_idx * num_rows + block_row;
+        }
+        __syncthreads();
+    }
+}
+
+// Top-K
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+        int* source_rows, const int k, const int start_expert, const int end_expert)
+{
+    // We begin by enforcing compile time assertions and setting up compile time constants.
+    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
+    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
+    static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
+    static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
+
+    // Number of bytes each thread pulls in per load
+    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
+    static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
+    static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
+    static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
+
+  // more compile-time assertions based on the previous section
+  static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
+  static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
+  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
+  static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
+
+  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
+  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
+  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
+
+  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elts per warp");
+
+  // let's finally compute runtime variables
+
+  // compute CTA and warp rows. we pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
+  // each block processes a chunk of rows. Start by computing the start row for each block.
+  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
+
+  // now, using the base row per thread block, compute the base row per warp
+  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
+
+  // the threads in a warp are split into sub-groups that will work in a row.
+  // compute row offset for each thread sub-group
+  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
+  const int thread_row = warp_base_row + thread_row_in_warp;
+
+  // threads with indices out of bounds should early exit here
+  if (thread_row >= num_rows)
+  {
+    return;
+  }
+  const bool row_is_active = finished ? !finished[thread_row] : true;
+
+  // finally start setting up the read pointers for each thread.
+  // first, each thread jumps to the start of the row it'll read
+  const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
+
+  // now we compute the group each thread belongs to in order to determine the
+  // first column to start loads
+  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
+  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
+  const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
+
+  // determine the pointer type to use to read in the data depending on the
+  // BYTES_PER_LDG template parameter
+  // this should support all powers of 2 up to 16
+  // NOTE: the original TensorRT-LLM implementation uses CUTLASS aligned arrays here
+  // we define our own aligned array and use it here to avoid using CUTLASS
+  using AccessType = AlignedArray<float, ELTS_PER_LDG>;
+
+  // finally, we put in the data from global memory
+  float row_chunk[VPT];
+  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(row_chunk);
+  const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
+#pragma unroll
+  for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
+  {
+    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
+  }
+
+  // first, we perform a max reduce within the thread. we can do the max in fp16
+  // safely (I think) and just convert to float afterwards for the exp + sum reduction
+  float thread_max = row_chunk[0];
+#pragma unroll
+  for (int ii = 1; ii < VPT; ++ii)
+  {
+    thread_max = max(thread_max, row_chunk[ii]);
+  }
+
+  // now, we find the max within the thread group and distribute among the threads. we use the butterfly
+  // all-reduce algorithm
+#pragma unroll
+  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
+  {
+    thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
+  }
+
+  // from this point, thread max in all the threads have the max within the row.
+  // now, we subtract the max from each element in the thread and take the exp.
+  // we also compute the thread local sum
+  float row_sum = 0;
+#pragma unroll
+  for (int ii = 0; ii < VPT; ++ii)
+  {
+    row_chunk[ii] = exp(row_chunk[ii] - thread_max);
+    row_sum += row_chunk[ii];
+  }
+
+  // now we perform a sum reduction within the thread group
+  // we use the butterfly all-reduce algorithm
+#pragma unroll
+  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
+  {
+    row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
+  }
+
+  // from this point, all threads have the max and the sum for their rows in the
+  // thread_max and thread_sum variables respectively
+  // finally, we can scale the rows for the softmax. technically, for top-k gating
+  // we don't need to compute the entire softmax row. we can likely look at the
+  // maxes and only compute for the top-k values in the row.
+  // this kernel will likely not be a bottleneck
+  const float reciprocal_row_sum = 1.f / row_sum;
+#pragma unroll
+  for (int ii = 0; ii < VPT; ++ii)
+  {
+    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
+  }
+
+  // now, row_chunk contains the softmax of the row chunk. now, let's find the
+  // top-k elements in each row, along with the max index
+  int start_col = first_elt_read_by_thread;
+  static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
+
+  for (int k_idx = 0; k_idx < k; ++k_idx)
+  {
+    // first each thread does the local argmax
+    float max_val = row_chunk[0];
+    int expert = start_col;
+#pragma unroll
+    for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
+    {
+#pragma unroll
+      for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
+      {
+        float val = row_chunk[ldg * ELTS_PER_LDG + ii];
+        // no check on the experts here since columns with the smallest index are processed
+        // first and only updated if > (not >=)
+        if (val > max_val)
+        {
+          max_val = val;
+          expert = col + ii;
+        }
+      }
+  }
+
+  // now we perform the argmax reduce. we use the butterfly pattern again so threads reach
+  // consensus about the max. this will be useful for K > 1 so that the threads can agree on "who"
+  // had the max value. that thread can then blank out their max with -inf and the warp can run
+  // more iterations
+#pragma unroll
+  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
+  {
+    float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
+    int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
+
+    // we want lower indices to "win" in every thread so we break ties this way
+    if (other_max > max_val || (other_max == max_val && other_expert < expert))
+    {
+      max_val = other_max;
+      expert = other_expert;
+    }
+  }
+
+  // write the max for this k iteration to global memory
+  if (thread_group_idx == 0)
+  {
+    // add a guard to ignore experts not included by this node
+    const bool node_uses_expert = expert >= start_expert && expert < end_expert;
+    const bool should_process_row = row_is_active && node_uses_expert;
+
+    // this lead thread from each sub-group will write out the final results to global memory
+    // This will be a single thread per row of the input/output matrices
+    const int idx = k * thread_row + k_idx;
+    output[idx] = max_val;
+    indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
+    source_rows[idx] = k_idx * num_rows + thread_row;
+  }
+
+  // finally, we clear the value in the thread with the current max if there is another iteration
+  // to run
+  if (k_idx + 1 < k)
+  {
+    const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
+    const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
+
+    // only the thread in the group which produced the max will reset the "winning"
+    // value to -inf
+    if (thread_group_idx == thread_to_clear_in_group)
+    {
+      const int offset_for_expert = expert % ELTS_PER_LDG;
+      // safe to set to any negative value since row_chunk values must be between 0 and 1
+      row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
+    }
+  }
+}
+}
+
+namespace detail
+{
+// constructs some constants needed to partition the work across threads at compile time
+template <int EXPERTS, int BYTES_PER_LDG>
+struct TopkConstants
+{
+  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
+  static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
+  static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+  static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
+  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
+  static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
+};
+} // namespace detail
+
+template <int EXPERTS, int WARPS_PER_TB>
+void topkGatingSoftmaxLauncherHelper(
+  const float* input, const bool* finished, float* output, int* indices, int* source_row,
+  const int num_rows, const int k, const int start_expert, const int end_expert,
+  cudaStream_t stream)
+{
+  static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
+
+  static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
+  using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
+  static constexpr int VPT = Constants::VPT;
+  static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
+  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
+  const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
+
+  dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
+  topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
+    input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
+}
+
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)             \
+  topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
+    gating_output, nullptr, topk_weights, topk_indices,       \
+    token_expert_indices, num_tokens, topk, 0, num_experts,   \
+    stream);
+
+void topkGatingSoftmaxKernelLauncher(
+  const float* gating_output,
+  float* topk_weights,
+  int* topk_indices,
+  int* token_expert_indices,
+  float* softmax_workspace,
+  const int num_tokens,
+  const int num_experts,
+  const int topk,
+  cudaStream_t stream) {
+  static constexpr int WARPS_PER_TB = 4;
+  switch (num_experts) {
+    case 1:
+      LAUNCH_SOFTMAX(1, WARPS_PER_TB);
+      break;
+    case 2:
+      LAUNCH_SOFTMAX(2, WARPS_PER_TB);
+      break;
+    case 4:
+      LAUNCH_SOFTMAX(4, WARPS_PER_TB);
+      break;
+    case 8:
+      LAUNCH_SOFTMAX(8, WARPS_PER_TB);
+      break;
+    case 16:
+      LAUNCH_SOFTMAX(16, WARPS_PER_TB);
+      break;
+    case 32:
+      LAUNCH_SOFTMAX(32, WARPS_PER_TB);
+      break;
+    case 64:
+      LAUNCH_SOFTMAX(64, WARPS_PER_TB);
+      break;
+    case 128:
+      LAUNCH_SOFTMAX(128, WARPS_PER_TB);
+      break;
+    case 256:
+      LAUNCH_SOFTMAX(256, WARPS_PER_TB);
+      break;
+    default: {
+      TORCH_CHECK(softmax_workspace != nullptr,
+            "softmax_workspace must be provided for num_experts that aren't a power of 2.");
+      static constexpr int TPB = 256;
+      moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
+        gating_output, nullptr, softmax_workspace, num_experts);
+      moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
+        softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
+        num_experts, topk, 0, num_experts);
+    }
+  }
+}
+
+} // namespace moe
+} // namespace aphrodite
+
+void topk_softmax(
+  torch::Tensor& topk_weights,
+  torch::Tensor& topk_indices,
+  torch::Tensor& token_expert_indices,
+  torch::Tensor& gating_output)
+{
+  const int num_experts = gating_output.size(-1);
+  const int num_tokens = gating_output.numel() / num_experts;
+  const int topk = topk_weights.size(-1);
+
+  const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
+  const bool needs_workspace = !is_pow_2 || num_experts > 256;
+  const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
+  aphrodite::moe::topkGatingSoftmaxKernelLauncher(
+    gating_output.data_ptr<float>(),
+    topk_weights.data_ptr<float>(),
+    topk_indices.data_ptr<int>(),
+    token_expert_indices.data_ptr<int>(),
+    softmax_workspace.data_ptr<float>(),
+    num_tokens,
+    num_experts,
+    topk,
+    stream);
+}

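For readers who want the semantics of the kernels above without the CUDA details, here is a rough PyTorch reference of what topkGatingSoftmaxKernelLauncher produces: a row-wise softmax followed by top-k, plus the source-row permutation written as source_rows[idx] = k_idx * num_rows + row. This is only a sketch; exact tie-breaking between equal probabilities may differ from the kernel.

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # Row-wise softmax over experts, then the k largest probabilities per token.
    num_tokens = gating_output.shape[0]
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    # Mirrors source_rows[idx] = k_idx * num_rows + row in the kernels above:
    # a permutation index used later to scatter tokens to their experts.
    k_idx = torch.arange(topk, device=gating_output.device)
    rows = torch.arange(num_tokens, device=gating_output.device)
    token_expert_indices = (k_idx[None, :] * num_tokens + rows[:, None]).to(torch.int32)
    return topk_weights, topk_ids.to(torch.int32), token_expert_indices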
+ 11 - 1
setup.py

@@ -303,7 +303,17 @@ if _is_cuda():
     aphrodite_extension_sources.append("kernels/quantization/quip/origin_order.cu")
     aphrodite_extension_sources.append("kernels/quantization/marlin/marlin_cuda_kernel.cu")
     aphrodite_extension_sources.append("kernels/all_reduce/custom_all_reduce.cu")
-
+    
+    ext_modules.append(
+        CUDAExtension(
+            name="aphrodite._moe_C",
+            sources=glob("kernels/moe/*.cu") + glob("kernels/moe/*.cpp"),
+            extra_compile_args={
+                "cxx": CXX_FLAGS,
+                "nvcc": NVCC_FLAGS,
+            },
+        ))
+    
 aphrodite_extension = CUDAExtension(
     name="aphrodite._C",
     sources=aphrodite_extension_sources,
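
Once the package is rebuilt (for example with pip install -e .), the new extension should be importable. A quick smoke check, assuming a CUDA build:

import aphrodite._moe_C as moe_kernels
assert hasattr(moe_kernels, "topk_softmax")
print("aphrodite._moe_C loaded from", moe_kernels.__file__)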