123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500 |
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cub/cub.cuh>
#include <cub/util_type.cuh>

#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdint>
namespace aphrodite {
namespace moe {

// Lanes per warp. The fused kernel below uses full-warp shuffle masks
// (0xFFFFFFFF) and a (WARP_SIZE, WARPS_PER_TB) block shape built on this value.
static constexpr int WARP_SIZE = 32;

// Fixed-size array of N elements of type T with an explicit alignment, so that a
// plain array reinterpret_cast to this type can be moved with one vectorized
// load/store. The storage type is T, which keeps sizeof(AlignedArray) equal to
// sizeof(T) * N (the default alignment). Alignment must be a power of two, as
// required by alignas.
template <
    typename T,
    int N,
    int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
    T data[N];
};

// Row-wise softmax: output[row, c] = exp(input[row, c] - max_row) / sum_row.
//
// Launch: gridDim.x == number of rows, blockDim.x == TPB, no dynamic shared
// memory. `input` and `output` are dense row-major buffers with `num_cols`
// elements per row. When `finished` is non-null and finished[row] is true, the
// row is skipped and its output is left untouched.
template <int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;
    __shared__ float normalizing_factor;
    __shared__ float float_max;

    // finished[] is indexed by block, so the whole block leaves together and no
    // thread is stranded at the barriers below.
    if ((finished != nullptr) && finished[blockIdx.x])
    {
        return;
    }

    // 64-bit offset: rows * num_cols can exceed INT_MAX for large batches.
    const int64_t thread_row_offset = static_cast<int64_t>(blockIdx.x) * num_cols;

    // Pass 1: row maximum (subtracted before exp for numerical stability).
    float threadData = -FLT_MAX;
    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        threadData = max(input[thread_row_offset + ii], threadData);
    }
    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
    if (threadIdx.x == 0)
    {
        // The block-wide aggregate is only valid in thread 0; publish it.
        float_max = maxElem;
    }
    // Makes float_max visible and allows tmpStorage to be reused.
    __syncthreads();

    // Pass 2: normalizer Z = sum(exp(x - max)).
    threadData = 0.f;
    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        threadData += expf(input[thread_row_offset + ii] - float_max);
    }
    const float Z = BlockReduce(tmpStorage).Reduce(threadData, cub::Sum());
    if (threadIdx.x == 0)
    {
        normalizing_factor = 1.f / Z;
    }
    __syncthreads();

    // Pass 3: write the normalized probabilities.
    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int64_t idx = thread_row_offset + ii;
        output[idx] = expf(input[idx] - float_max) * normalizing_factor;
    }
}

// Selects the k largest entries of each row of an already-softmaxed matrix.
//
// Launch: gridDim.x == number of rows, blockDim.x == TPB. For row r and round
// k_idx, the results go to slot r * k + k_idx:
//   output      : the selected probability,
//   indices     : expert - start_expert when the row is active and the expert
//                 lies in [start_expert, end_expert), otherwise num_experts,
//   source_rows : k_idx * num_rows + r.
// Winners come out in descending probability order, with ties broken toward
// the smaller expert id (cub::ArgMax semantics).
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{
    using cub_kvp = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;
    // Winner of the previous round, broadcast from thread 0 to the whole block.
    // Earlier winners are excluded with this value rather than by reading
    // indices[], because indices[] holds the remapped expert id
    // (expert - start_expert, or num_experts), not the raw column.
    __shared__ int prev_key;
    __shared__ float prev_value;

    cub::ArgMax arg_max;

    const int num_rows = gridDim.x;
    const int block_row = blockIdx.x;

    const bool row_is_active = finished ? !finished[block_row] : true;
    const int64_t thread_read_offset = static_cast<int64_t>(block_row) * num_experts;
    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        // Sentinel below any softmax probability.
        cub_kvp thread_kvp;
        thread_kvp.key = 0;
        thread_kvp.value = -1.f;

        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
        {
            cub_kvp inp_kvp;
            inp_kvp.key = expert;
            inp_kvp.value = inputs_after_softmax[thread_read_offset + expert];
            // Winners are produced in (value desc, key asc) order. The experts
            // already taken are therefore exactly those ranked at or before the
            // previous winner.
            const bool already_selected = (k_idx > 0)
                && ((inp_kvp.value > prev_value) || ((inp_kvp.value == prev_value) && (inp_kvp.key <= prev_key)));
            if (!already_selected)
            {
                thread_kvp = arg_max(inp_kvp, thread_kvp);
            }
        }

        const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
        if (threadIdx.x == 0)
        {
            const int expert = result_kvp.key;
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            const int idx = k * block_row + k_idx;
            const int stored_index = should_process_row ? (expert - start_expert) : num_experts;
            output[idx] = result_kvp.value;
            indices[idx] = stored_index;
            assert(stored_index >= 0);
            source_rows[idx] = k_idx * num_rows + block_row;

            prev_key = result_kvp.key;
            prev_value = result_kvp.value;
        }
        // Publishes prev_key/prev_value and allows tmpStorage to be reused.
        __syncthreads();
    }
}

// Fused softmax + top-k for a power-of-two number of experts (NUM_EXPERTS).
//
// Launch: blockDim == (WARP_SIZE, WARPS_PER_CTA). Each row is owned by
// THREADS_PER_ROW consecutive lanes, and each lane holds VPT values of the row
// in registers. Each warp covers ROWS_PER_WARP rows and each block covers
// ROWS_PER_CTA rows. Rows are read with BYTES_PER_LDG-wide vector loads, so
// `input` must be aligned to BYTES_PER_LDG. Outputs are written with the same
// layout and remapping as moeTopK.
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
    int* source_rows, const int k, const int start_expert, const int end_expert)
{
    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
    static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
    static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

    // Floats moved by one vectorized load.
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
    // Lanes that cooperate on a single row.
    static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
    // Vectorized loads issued by each lane.
    static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

    static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
    static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
    static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
    static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");

    static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
    static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
    static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

    static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");

    // Columns covered when every lane of a row group issues one vectorized load.
    // Load `ldg` of lane `t` reads columns
    // [ldg * COLS_PER_GROUP_LDG + t * ELTS_PER_LDG, ... + ELTS_PER_LDG).
    static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

    const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
    const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
    const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
    const int thread_row = warp_base_row + thread_row_in_warp;

    // A row group either exits as a whole or participates as a whole, so the
    // width-limited shuffles below only read from live lanes.
    if (thread_row >= num_rows)
    {
        return;
    }

    const bool row_is_active = finished ? !finished[thread_row] : true;
    // 64-bit offset to avoid int overflow for very large row counts.
    const float* thread_row_ptr = input + static_cast<int64_t>(thread_row) * ELTS_PER_ROW;

    const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
    const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
    const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

    // Strided vector loads of this lane's VPT values into registers.
    using AccessType = AlignedArray<float, ELTS_PER_LDG>;
    float row_chunk[VPT];
    AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
    const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
    for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
    {
        row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
    }

    // Row max: per-lane max, then butterfly reduction across the row group.
    float thread_max = row_chunk[0];
#pragma unroll
    for (int ii = 1; ii < VPT; ++ii)
    {
        thread_max = max(thread_max, row_chunk[ii]);
    }
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
    }

    // exp(x - max) and the row sum, reduced the same way.
    float row_sum = 0;
#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = expf(row_chunk[ii] - thread_max);
        row_sum += row_chunk[ii];
    }
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
    }

    const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
    }

    // k rounds of arg-max over the row. After each round the lane owning the
    // winner overwrites it with a value below any probability.
    int start_col = first_elt_read_by_thread;
    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        // Lane-local arg-max; strict '>' keeps the smallest column on ties.
        float max_val = row_chunk[0];
        int expert = start_col;
#pragma unroll
        for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
        {
#pragma unroll
            for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
            {
                float val = row_chunk[ldg * ELTS_PER_LDG + ii];
                if (val > max_val)
                {
                    max_val = val;
                    expert = col + ii;
                }
            }
        }

        // Row-group arg-max; ties go to the smaller expert id.
#pragma unroll
        for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
        {
            float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
            int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
            if (other_max > max_val || (other_max == max_val && other_expert < expert))
            {
                max_val = other_max;
                expert = other_expert;
            }
        }

        if (thread_group_idx == 0)
        {
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            const int idx = k * thread_row + k_idx;
            output[idx] = max_val;
            indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
            source_rows[idx] = k_idx * num_rows + thread_row;
        }

        // Knock the winner out of the owning lane's registers for the next round.
        if (k_idx + 1 < k)
        {
            const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
            const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
            if (thread_group_idx == thread_to_clear_in_group)
            {
                const int offset_for_expert = expert % ELTS_PER_LDG;
                row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
            }
        }
    }
}

namespace detail
{
// Compile-time launch geometry for topkGatingSoftmax with EXPERTS experts and
// BYTES_PER_LDG-byte vector loads.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    // Either a whole row fits in one load per lane of a warp, or it is a whole
    // multiple of that span.
    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
    static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
    static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
}  // namespace detail

// Launches topkGatingSoftmax for EXPERTS experts with WARPS_PER_TB warps per
// block on `stream`. Loads are min(16, 4 * EXPERTS) bytes wide.
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

    static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
    static constexpr int VPT = Constants::VPT;
    static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
    const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
    const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

    dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
        input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}

// Expands to a fused-kernel launch specialised for NUM_EXPERTS experts. It
// expects the parameters of topkGatingSoftmaxKernelLauncher to be in scope.
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                   \
    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(     \
        gating_output, nullptr, topk_weights, topk_indicies,        \
        token_expert_indices, num_tokens, topk, 0, num_experts,     \
        stream)

// Computes softmax + top-k over `num_experts` logits for each of `num_tokens`
// rows on `stream`. Power-of-two expert counts up to 256 use the fused kernel.
// Any other count runs moeSoftmax into `softmax_workspace` (at least
// num_tokens * num_experts floats, required non-null) and then moeTopK.
void topkGatingSoftmaxKernelLauncher(
    const float* gating_output,
    float* topk_weights,
    int* topk_indicies,
    int* token_expert_indices,
    float* softmax_workspace,
    const int num_tokens,
    const int num_experts,
    const int topk,
    cudaStream_t stream) {
    static constexpr int WARPS_PER_TB = 4;
    // Nothing to do; a zero-block grid would also be an invalid launch.
    if (num_tokens <= 0) {
        return;
    }
    switch (num_experts) {
        case 1:
            LAUNCH_SOFTMAX(1, WARPS_PER_TB);
            break;
        case 2:
            LAUNCH_SOFTMAX(2, WARPS_PER_TB);
            break;
        case 4:
            LAUNCH_SOFTMAX(4, WARPS_PER_TB);
            break;
        case 8:
            LAUNCH_SOFTMAX(8, WARPS_PER_TB);
            break;
        case 16:
            LAUNCH_SOFTMAX(16, WARPS_PER_TB);
            break;
        case 32:
            LAUNCH_SOFTMAX(32, WARPS_PER_TB);
            break;
        case 64:
            LAUNCH_SOFTMAX(64, WARPS_PER_TB);
            break;
        case 128:
            LAUNCH_SOFTMAX(128, WARPS_PER_TB);
            break;
        case 256:
            LAUNCH_SOFTMAX(256, WARPS_PER_TB);
            break;
        default: {
            TORCH_CHECK(softmax_workspace != nullptr,
                "softmax_workspace must be provided for num_experts that are not a power of 2.");
            static constexpr int TPB = 256;
            moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
                gating_output, nullptr, softmax_workspace, num_experts);
            moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
                softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
                num_experts, topk, 0, num_experts);
        }
    }
}

}  // namespace moe
}  // namespace aphrodite
// Host entry point: softmax over the last dimension of `gating_output`, then
// the top-k probabilities and expert ids for every token.
//
// gating_output:        float32, contiguous. The last dim is the expert count
//                       and all leading dims are flattened into tokens.
// topk_weights:         float32, contiguous output. Its last dim is k.
// topk_indices:         int32, contiguous output (expert id per selection).
// token_expert_indices: int32, contiguous output. Entry [t, j] receives
//                       j * num_tokens + t.
// All work is enqueued on the current CUDA stream of gating_output's device.
void topk_softmax(
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& gating_output)
{
    const int num_experts = gating_output.size(-1);
    TORCH_CHECK(num_experts > 0, "gating_output must have a non-empty last dimension");
    const int num_tokens = gating_output.numel() / num_experts;
    const int topk = topk_weights.size(-1);

    // The kernels index these buffers as dense row-major arrays.
    TORCH_CHECK(gating_output.is_contiguous(), "gating_output must be contiguous");
    TORCH_CHECK(topk_weights.is_contiguous(), "topk_weights must be contiguous");
    TORCH_CHECK(topk_indices.is_contiguous(), "topk_indices must be contiguous");
    TORCH_CHECK(token_expert_indices.is_contiguous(), "token_expert_indices must be contiguous");

    if (num_tokens == 0) {
        return;
    }

    // Power-of-two counts up to 256 take the fused path and need no workspace.
    const bool is_pow_2 = (num_experts & (num_experts - 1)) == 0;
    const bool needs_workspace = !is_pow_2 || num_experts > 256;
    // numel() is int64, so the size cannot overflow the way int * int could.
    const int64_t workspace_size = needs_workspace ? gating_output.numel() : 0;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
    aphrodite::moe::topkGatingSoftmaxKernelLauncher(
        gating_output.data_ptr<float>(),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        token_expert_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        stream);
}