fused_moe.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. """Fused MoE kernel."""
  2. import functools
  3. import json
  4. import os
  5. from typing import Any, Dict, Optional, Tuple
  6. import torch
  7. import triton
  8. import triton.language as tl
  9. from loguru import logger
  10. from aphrodite._C import ops
  11. from aphrodite.common.utils import is_hip
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor of tokens, addressed by row (`stride_am`) and by
        feature (`stride_ak`), with K features per row. Entry i of the
        flattened (token, top_k) index space reads row `i // top_k` of A.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output tensor. Row `sorted_token_ids[i]` (a flattened
        (token, top_k) index) is written at `stride_cm`, column n at
        `stride_cn`. The caller in this file allocates C as (M, topk, N),
        where M is the number of tokens (not padded), and passes
        `C.stride(1)` / `C.stride(2)` as the strides.
    - sorted_token_ids: A tensor containing the flattened (token, top_k)
        indices grouped by the expert they are assigned to and padded per
        expert to a multiple of BLOCK_SIZE_M. Entries >= `num_valid_tokens`
        are padding and are masked out of every load and store.
    - expert_ids: A tensor containing the index of the expert for each
        BLOCK_SIZE_M-sized block of `sorted_token_ids`. It determines which
        expert matrix from B is used for that block of A.
    - num_tokens_post_padded_ptr: One-element tensor holding the padded
        length actually used; program blocks starting at or beyond it exit
        immediately.
    - EM: Length of `sorted_token_ids` (the worst-case padded size); only
        used to derive the number of M blocks in the pid mapping.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M rows of blocks.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # The launch grid is sized for the worst-case padded length; blocks past
    # the padded length actually produced have nothing to compute.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding entries are >= num_valid_tokens; mask them out.
    token_mask = offs_token < num_valid_tokens

    # Wrap out-of-range N columns back into [0, N) so the B loads below need
    # no N mask; those columns are discarded by `c_mask` at the store.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `offs_token` indexes the flattened (token, top_k) space; dividing by
    # top_k recovers the source row of A (each token row is shared by its
    # top_k expert assignments).
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    # Every row in this M block belongs to the same expert.
    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to `compute_type` after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Routing weights are read at the flattened (token, top_k) index.
        # NOTE(review): assumes topk_weights is contiguous (the caller only
        # asserts stride(1) == 1) — confirm.
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    # Drop padding rows and the wrapped (>= N) columns.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
  137. def moe_align_block_size(
  138. topk_ids: torch.Tensor, block_size: int,
  139. num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  140. """
  141. Aligns the token distribution across experts to be compatible with block
  142. size for matrix multiplication.
  143. Parameters:
  144. - topk_ids: A tensor of shape [total_tokens, top_k] representing the
  145. top-k expert indices for each token.
  146. - block_size: The block size used in block matrix multiplication.
  147. - num_experts: The total number of experts.
  148. Returns:
  149. - sorted_token_ids: A tensor containing the sorted token indices according
  150. to their allocated expert.
  151. - expert_ids: A tensor indicating the assigned expert index for each block.
  152. - num_tokens_post_padded: The total number of tokens after padding,
  153. ensuring divisibility by block_size.
  154. This function pads the number of tokens that each expert needs to process
  155. so that it is divisible by block_size.
  156. Padding ensures that during block matrix multiplication, the dimensions
  157. align correctly.
  158. Example:
  159. Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
  160. block_size = 4, and num_experts = 4:
  161. - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
  162. with each expert needing to process 3 tokens.
  163. - As block_size is 4, we pad 1 token for each expert.
  164. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
  165. - Then append padding tokens [12, 12, 12, 12] for each block.
  166. - After sorting by expert index, we obtain token_ids
  167. [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
  168. Tokens 12 are non-existent (padding) and are ignored in
  169. the subsequent matrix multiplication.
  170. - The padding ensures that the total number of tokens is now divisible
  171. by block_size for proper block matrix operations.
  172. """
  173. sorted_ids = torch.empty(
  174. (topk_ids.numel() + num_experts * (block_size - 1), ),
  175. dtype=torch.int32,
  176. device=topk_ids.device)
  177. expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
  178. dtype=torch.int32,
  179. device=topk_ids.device)
  180. sorted_ids.fill_(topk_ids.numel())
  181. num_tokens_post_pad = torch.empty((1),
  182. dtype=torch.int32,
  183. device=topk_ids.device)
  184. ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
  185. expert_ids, num_tokens_post_pad)
  186. return sorted_ids, expert_ids, num_tokens_post_pad
  187. def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
  188. topk_weights: torch.Tensor, topk_ids: torch.Tensor,
  189. sorted_token_ids: torch.Tensor,
  190. expert_ids: torch.Tensor,
  191. num_tokens_post_padded: torch.Tensor,
  192. mul_routed_weight: bool, top_k: int,
  193. config: Dict[str, Any]) -> None:
  194. assert topk_weights.stride(1) == 1
  195. assert sorted_token_ids.stride(0) == 1
  196. grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
  197. 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
  198. fused_moe_kernel[grid](
  199. A,
  200. B,
  201. C,
  202. topk_weights,
  203. sorted_token_ids,
  204. expert_ids,
  205. num_tokens_post_padded,
  206. B.shape[1],
  207. B.shape[2],
  208. sorted_token_ids.shape[0],
  209. topk_ids.numel(),
  210. A.stride(0),
  211. A.stride(1),
  212. B.stride(0),
  213. B.stride(2),
  214. B.stride(1),
  215. C.stride(1),
  216. C.stride(2),
  217. MUL_ROUTED_WEIGHT=mul_routed_weight,
  218. top_k=top_k,
  219. compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
  220. **config,
  221. )
  222. def fused_topk(
  223. gating_output: torch.Tensor,
  224. topk: int,
  225. renormalize: bool,
  226. ):
  227. """Compute top-k indice and weights from gating logits
  228. Args:
  229. gating_output (torch.Tensor): The output of the gating operation
  230. (before softmax).
  231. topk (int): The number of top-k experts to select.
  232. renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  233. """
  234. M = gating_output.shape[0]
  235. if is_hip():
  236. # The MoE kernels are not yet supported on ROCm.
  237. routing_weights = torch.softmax(gating_output,
  238. dim=-1,
  239. dtype=torch.float32)
  240. topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
  241. else:
  242. import aphrodite._moe_C as moe_kernels
  243. topk_weights = torch.empty(M,
  244. topk,
  245. dtype=torch.float32,
  246. device=gating_output.device)
  247. topk_ids = torch.empty(M,
  248. topk,
  249. dtype=torch.int32,
  250. device=gating_output.device)
  251. token_expert_indicies = torch.empty(M,
  252. topk,
  253. dtype=torch.int32,
  254. device=gating_output.device)
  255. moe_kernels.topk_softmax(
  256. topk_weights,
  257. topk_ids,
  258. token_expert_indicies,
  259. gating_output.float(), # TODO(woosuk): Optimize this.
  260. )
  261. del token_expert_indicies # Not used. Will be used in the future.
  262. if renormalize:
  263. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  264. return topk_weights, topk_ids
  265. def get_config_file_name(E: int, N: int) -> str:
  266. device_name = torch.cuda.get_device_name().replace(" ", "_")
  267. return f"E={E},N={N},device_name={device_name}.json"
  268. @functools.lru_cache
  269. def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
  270. """
  271. Return optimized configurations for the fused MoE kernel.
  272. The return value will be a dictionary that maps an irregular grid of
  273. batch sizes to configurations of the fused_moe kernel. To evaluate the
  274. kernel on a given batch size bs, the closest batch size in the grid should
  275. be picked and the associated configuration chosen to invoke the kernel.
  276. """
  277. # First look up if an optimized configuration is available in the configs
  278. # directory
  279. json_file_name = get_config_file_name(E, N)
  280. config_file_path = os.path.join(
  281. os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
  282. if os.path.exists(config_file_path):
  283. with open(config_file_path) as f:
  284. logger.info(
  285. f"Using configuration from {config_file_path} for MoE layer.")
  286. # If a configuration has been found, return it
  287. return {int(key): val for key, val in json.load(f).items()}
  288. # If no optimized configuration is available, we will use the default
  289. # configuration
  290. return None
  291. def fused_moe(
  292. hidden_states: torch.Tensor,
  293. w1: torch.Tensor,
  294. w2: torch.Tensor,
  295. gating_output: torch.Tensor,
  296. topk: int,
  297. renormalize: bool,
  298. inplace: bool = True,
  299. override_config: Optional[Dict[str, Any]] = None,
  300. ) -> torch.Tensor:
  301. """
  302. This function computes a Mixture of Experts (MoE) layer using two sets of
  303. weights, w1 and w2, and top-k gating mechanism.
  304. Parameters:
  305. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
  306. - w1 (torch.Tensor): The first set of expert weights.
  307. - w2 (torch.Tensor): The second set of expert weights.
  308. - gating_output (torch.Tensor): The output of the gating operation
  309. (before softmax).
  310. - topk (int): The number of top-k experts to select.
  311. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  312. - inplace (bool): If True, perform the operation in-place.
  313. Defaults to False.
  314. - override_config (Optional[Dict[str, Any]]): Optional override
  315. for the kernel configuration.
  316. Returns:
  317. - torch.Tensor: The output tensor after applying the MoE layer.
  318. """
  319. # Check constraints.
  320. assert hidden_states.shape[0] == gating_output.shape[0], (
  321. "Number of tokens mismatch")
  322. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
  323. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
  324. assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
  325. assert hidden_states.dtype in [
  326. torch.float32, torch.float16, torch.bfloat16
  327. ]
  328. M, _ = hidden_states.shape
  329. E, N, _ = w1.shape
  330. topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
  331. if override_config:
  332. config = override_config
  333. else:
  334. # First try to load optimal config from the file
  335. configs = get_moe_configs(E, w2.shape[2])
  336. if configs:
  337. # If an optimal configuration map has been found, look up the
  338. # optimal config
  339. config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
  340. else:
  341. # Else use the default config
  342. config = {
  343. 'BLOCK_SIZE_M': 64,
  344. 'BLOCK_SIZE_N': 64,
  345. 'BLOCK_SIZE_K': 32,
  346. 'GROUP_SIZE_M': 8
  347. }
  348. if M <= E:
  349. config = {
  350. 'BLOCK_SIZE_M': 16,
  351. 'BLOCK_SIZE_N': 32,
  352. 'BLOCK_SIZE_K': 64,
  353. 'GROUP_SIZE_M': 1
  354. }
  355. intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
  356. device=hidden_states.device,
  357. dtype=hidden_states.dtype)
  358. intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
  359. device=hidden_states.device,
  360. dtype=hidden_states.dtype)
  361. intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
  362. device=hidden_states.device,
  363. dtype=hidden_states.dtype)
  364. sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
  365. topk_ids, config['BLOCK_SIZE_M'], E)
  366. invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
  367. topk_weights, topk_ids, sorted_token_ids,
  368. expert_ids, num_tokens_post_padded, False,
  369. topk_ids.shape[1], config)
  370. ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
  371. invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
  372. topk_weights, topk_ids, sorted_token_ids,
  373. expert_ids, num_tokens_post_padded, True, 1,
  374. config)
  375. if inplace:
  376. return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
  377. dim=1,
  378. out=hidden_states)
  379. return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
  380. dim=1)