- """Fused MoE kernel."""
- import functools
- import json
- import os
- from typing import Any, Dict, Optional, Tuple
- import torch
- import triton
- import triton.language as tl
- from loguru import logger
- from aphrodite._C import ops
- from aphrodite.common.utils import is_hip


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)

    if use_fp8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
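

# The grouped `pid` -> (pid_m, pid_n) mapping at the top of the kernel is
# easiest to inspect outside Triton. The helper below is a minimal
# illustrative sketch (not used by the kernel) that mirrors the same
# arithmetic in plain Python so the L2-friendly launch order can be printed
# for a small grid; the function and parameter names are assumptions made for
# this example only.
def _grouped_launch_order(EM: int, N: int, block_m: int, block_n: int,
                          group_m: int):
    """Return the (pid_m, pid_n) tile visited by each program id."""
    num_pid_m = triton.cdiv(EM, block_m)
    num_pid_n = triton.cdiv(N, block_n)
    num_pid_in_group = group_m * num_pid_n
    order = []
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * group_m
        group_size_m = min(num_pid_m - first_pid_m, group_m)
        # Within a group, consecutive programs cycle through a small band of
        # row blocks before advancing the column block, so the same A rows
        # and B columns are reused by nearby programs.
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        order.append((pid_m, pid_n))
    return order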


def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
        with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    sorted_ids = torch.empty(
        (topk_ids.numel() + num_experts * (block_size - 1), ),
        dtype=torch.int32,
        device=topk_ids.device)
    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                             expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad
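

# For readers without the compiled extension at hand, the sketch below is a
# rough pure-PyTorch approximation of what `ops.moe_align_block_size` is
# assumed to compute: pad each expert's slice of the expert-sorted, flattened
# token indices up to a multiple of `block_size` with the sentinel value
# `topk_ids.numel()`. It is illustrative only, may differ from the CUDA op in
# edge cases (e.g. tie ordering), and is not used by this module.
def _moe_align_block_size_ref(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    flat = topk_ids.flatten()
    sentinel = flat.numel()
    sorted_ids, expert_ids = [], []
    for expert in range(num_experts):
        # Flattened positions of the tokens routed to this expert.
        token_idxs = (flat == expert).nonzero(as_tuple=True)[0].tolist()
        pad = (-len(token_idxs)) % block_size
        token_idxs += [sentinel] * pad
        sorted_ids += token_idxs
        expert_ids += [expert] * (len(token_idxs) // block_size)
    return (torch.tensor(sorted_ids, dtype=torch.int32),
            torch.tensor(expert_ids, dtype=torch.int32),
            torch.tensor([len(sorted_ids)], dtype=torch.int32))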


def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            A_scale: Optional[torch.Tensor],
                            B_scale: Optional[torch.Tensor],
                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool, top_k: int,
                            config: Dict[str, Any], compute_type: tl.dtype,
                            use_fp8: bool) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if not use_fp8:
        assert A_scale is None
        assert B_scale is None
    else:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None

    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        B.shape[2],
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8=use_fp8,
        **config,
    )
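
# As a rough worked example of the 1-D launch grid above (all numbers here are
# hypothetical): with sorted_token_ids.shape[0] == 4096, B.shape[1] == 14336,
# and the default config (BLOCK_SIZE_M=64, BLOCK_SIZE_N=64), the grid is
# cdiv(4096, 64) * cdiv(14336, 64) = 64 * 224 = 14336 programs, each computing
# one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C.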


def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"


@functools.lru_cache
def get_moe_configs(E: int, N: int,
                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """
    # First look up if an optimized configuration is available in the configs
    # directory
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info(
                f"Using configuration from {config_file_path} for MoE layer.")
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    return None
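
# A config file is expected to look roughly like the following, mapping batch
# sizes (as string keys) to kernel configs. The exact keys and values below
# are illustrative, not taken from a shipped tuning file:
#
#     {
#         "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32,
#               "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1},
#         "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64,
#                "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}
#     }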


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.
    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for the
        fp8 quantization of the activations fed to w1.
    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for the
        fp8 quantization of the activations fed to w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if is_hip():
        # The MoE kernels are not yet supported on ROCm.
        routing_weights = torch.softmax(gating_output,
                                        dim=-1,
                                        dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    else:
        import aphrodite._moe_C as moe_kernels

        topk_weights = torch.empty(M,
                                   topk,
                                   dtype=torch.float32,
                                   device=hidden_states.device)
        topk_ids = torch.empty(M,
                               topk,
                               dtype=torch.int32,
                               device=hidden_states.device)
        token_expert_indicies = torch.empty(M,
                                            topk,
                                            dtype=torch.int32,
                                            device=hidden_states.device)
        moe_kernels.topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),  # TODO(woosuk): Optimize this.
        )
        del token_expert_indicies  # Not used. Will be used in the future.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        configs = get_moe_configs(E, w2.shape[2],
                                  "float8" if use_fp8 else None)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }

            if M <= E:
                config = {
                    'BLOCK_SIZE_M': 16,
                    'BLOCK_SIZE_N': 32,
                    'BLOCK_SIZE_K': 64,
                    'GROUP_SIZE_M': 1
                }

    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config['BLOCK_SIZE_M'], E)

    invoke_fused_moe_kernel(hidden_states,
                            w1,
                            intermediate_cache1,
                            a1_scale,
                            w1_scale,
                            topk_weights,
                            topk_ids,
                            sorted_token_ids,
                            expert_ids,
                            num_tokens_post_padded,
                            False,
                            topk_ids.shape[1],
                            config,
                            compute_type=tl.float16,
                            use_fp8=use_fp8)

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(intermediate_cache2,
                            w2,
                            intermediate_cache3,
                            a2_scale,
                            w2_scale,
                            topk_weights,
                            topk_ids,
                            sorted_token_ids,
                            expert_ids,
                            num_tokens_post_padded,
                            True,
                            1,
                            config,
                            compute_type=tl.float16,
                            use_fp8=use_fp8)

    if inplace:
        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                         dim=1,
                         out=hidden_states)
    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     dim=1)
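

if __name__ == "__main__":
    # Minimal usage sketch. This is illustrative only: it assumes a CUDA
    # device plus the compiled Aphrodite extensions, and the shapes below are
    # arbitrary placeholders rather than a tuned or recommended setup.
    # w1 packs the gate and up projections of each expert, so its second
    # dimension is 2 * intermediate_size; w2 maps back to the hidden size.
    num_tokens, hidden_size, intermediate_size, num_experts = 16, 512, 1024, 8
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                dtype=torch.float16,
                                device="cuda")
    w1 = torch.randn(num_experts,
                     2 * intermediate_size,
                     hidden_size,
                     dtype=torch.float16,
                     device="cuda")
    w2 = torch.randn(num_experts,
                     hidden_size,
                     intermediate_size,
                     dtype=torch.float16,
                     device="cuda")
    gating_output = torch.randn(num_tokens,
                                num_experts,
                                dtype=torch.float16,
                                device="cuda")
    out = fused_moe(hidden_states, w1, w2, gating_output, topk=2,
                    renormalize=True)
    assert out.shape == (num_tokens, hidden_size)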