"""Fused MoE kernel.""" import functools import json import os from typing import Any, Dict, Optional, Tuple import torch import triton import triton.language as tl from loguru import logger from aphrodite import _custom_ops as ops from aphrodite.platforms import current_platform APHRODITE_FUSED_MOE_CHUNK_SIZE = int( os.getenv("APHRODITE_FUSED_MOE_CHUNK_SIZE", "65536")) @triton.jit def fused_moe_kernel( # Pointers to matrices a_ptr, b_ptr, c_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, # Matrix dimensions N, K, EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is # how much to increase `a_ptr` by to get the element one row down # (A has M rows). stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, stride_bse, stride_bsn, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. Key Parameters: - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, and N is the output feature dimension. - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. """ # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. # We will advance this pointer as we move in the K direction # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) off_experts = tl.load(expert_ids_ptr + pid_m) b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) if use_int8_w8a16: b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ None, :] * stride_bsn b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: a_scale = tl.load(a_scale_ptr) b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. a = tl.load(a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Parameters: - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. Returns: - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - expert_ids: A tensor indicating the assigned expert index for each block. - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. Padding ensures that during block matrix multiplication, the dimensions align correctly. Example: Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. - As block_size is 4, we pad 1 token for each expert. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - Then append padding tokens [12, 12, 12, 12] for each block. - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids = torch.empty((max_num_tokens_padded, ), dtype=torch.int32, device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) expert_ids = torch.empty((max_num_m_blocks, ), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, config: Dict[str, Any], compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None elif use_int8_w8a16: assert B_scale is not None else: assert A_scale is None assert B_scale is None grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) fused_moe_kernel[grid]( A, B, C, A_scale, B_scale, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, B.shape[1], B.shape[2], sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), A.stride(1), B.stride(0), B.stride(2), B.stride(1), C.stride(1), C.stride(2), B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, **config, ) def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" @functools.lru_cache def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. The return value will be a dictionary that maps an irregular grid of batch sizes to configurations of the fused_moe kernel. To evaluate the kernel on a given batch size bs, the closest batch size in the grid should be picked and the associated configuration chosen to invoke the kernel. """ # First look up if an optimized configuration is available in the configs # directory json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info(f"Using configuration from {config_file_path} " "for MoE layer.") # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration return None def get_default_config( M: int, E: int, N: int, K: int, topk: int, dtype: Optional[str], ) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } if M <= E: config = { 'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1 } return config def try_get_optimal_moe_config( w1_shape: Tuple[int, ...], w2_shape: Tuple[int, ...], top_k: int, dtype: Optional[str], M: int, override_config: Optional[Dict[str, Any]] = None, ): if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape configs = get_moe_configs(E, N, dtype) if configs: # If an optimal configuration map has been found, look up the # optimal config config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) return config def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, ): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") M, _ = hidden_states.shape topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device) topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) ops.topk_softmax( topk_weights, topk_ids, token_expert_indicies, gating_output.float(), # TODO: Optimize this. ) del token_expert_indicies # Not used. Will be used in the future. if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids # This is used by the Deepseek-V2 model def grouped_topk(hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, num_expert_group: int = 0, topk_group: int = 0): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") scores = torch.softmax(gating_output, dim=-1) num_token = scores.shape[0] group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values # [n, n_group] group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = group_mask.unsqueeze(-1).expand( num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): if use_fp8_w8a8: return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs return "float32" return None def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] num_tokens, _ = hidden_states.shape E, N, _ = w1.shape # We execute the fused_moe kernel in chunks. CHUNK_SIZE = APHRODITE_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, dtype=hidden_states.dtype) get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, w2.shape, topk_ids.shape[1], config_dtype, override_config=override_config, ) config = get_config_func(M) intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) if inplace: out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape if tokens_in_chunk == 0: break if tokens_in_chunk < CHUNK_SIZE and chunk > 0: # Adjust the intermediate cache size and config for the last # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and # do not need to be adjusted. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) invoke_fused_moe_kernel(curr_hidden_states, w1, intermediate_cache1, a1_scale, w1_scale, curr_topk_weights, curr_topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, False, topk_ids.shape[1], config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, a2_scale, w2_scale, curr_topk_weights, curr_topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, True, 1, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16) torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - inplace (bool): If True, perform the operation in-place. Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" if use_grouped_topk: assert num_expert_group is not None and topk_group is not None topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk, renormalize, num_expert_group, topk_group) else: topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) return fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, inplace=inplace, override_config=override_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale)