# fused_moe.py
  1. """Fused MoE kernel."""
  2. import functools
  3. import json
  4. import os
  5. from typing import Any, Dict, Optional, Tuple
  6. import torch
  7. import triton
  8. import triton.language as tl
  9. from loguru import logger
  10. from aphrodite._C import ops
  11. from aphrodite.common.utils import is_hip


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8: tl.constexpr,
):
  50. """
  51. Implements the fused computation for a Mixture of Experts (MOE) using
  52. token and expert matrices.
  53. Key Parameters:
  54. - A: The input tensor representing tokens with shape (*, K), where '*' can
  55. be any shape representing batches and K is the feature dimension of
  56. each token.
  57. - B: The stacked MOE weight tensor with shape (E, N, K), where E is
  58. the number of experts, K is the input feature dimension, and N is
  59. the output feature dimension.
  60. - C: The output cache tensor with shape (M, topk, N), where M is the
  61. total number of tokens post padding, topk is the number of times
  62. each token is repeated, and N is the output feature dimension.
  63. - sorted_token_ids: A tensor containing the sorted indices of tokens,
  64. repeated topk times and arranged by the expert index they are
  65. assigned to.
  66. - expert_ids: A tensor containing the indices of the expert for each
  67. block. It determines which expert matrix from B should be used for
  68. each block in A.
  69. This kernel performs the multiplication of a token by its corresponding
  70. expert matrix as determined by `expert_ids`. The sorting of
  71. `sorted_token_ids` by expert index and padding ensures divisibility by
  72. BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
  73. multiplication across different blocks processed by the same expert.
  74. """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate.
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)

    if use_fp8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)

    # -----------------------------------------------------------
    # Write back the block of the output.
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
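

# For intuition: each program of `fused_moe_kernel` computes one
# (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C. Ignoring blocking, masking and fp8
# scaling, the per-token computation is roughly (illustrative sketch only,
# not part of the kernel):
#
#     t = sorted_token_ids[slot]           # topk-repeated token index
#     e = expert_ids[slot // BLOCK_SIZE_M]
#     C[t] = A[t // top_k] @ B[e].T        # B[e] has shape (N, K)
#     if MUL_ROUTED_WEIGHT:
#         C[t] *= topk_weights[t]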


def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
      top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
      to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
      ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size. Padding ensures that during block
    matrix multiplication, the dimensions align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4
      experts, with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
      Tokens 12 are non-existent (padding) and are ignored in
      the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
      by block_size for proper block matrix operations.
    """
    sorted_ids = torch.empty(
        (topk_ids.numel() + num_experts * (block_size - 1), ),
        dtype=torch.int32,
        device=topk_ids.device)
    expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    num_tokens_post_pad = torch.empty((1, ),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                             expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad
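

# The heavy lifting above is done by the C++/CUDA op `ops.moe_align_block_size`.
# The following pure-Python sketch is illustrative only (a hypothetical
# reference, not used by the kernel path); the ordering of tokens within an
# expert may differ from the CUDA op, but the grouping and padding match.
def _moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int,
                                    num_experts: int):
    flat = topk_ids.flatten().tolist()
    pad_id = len(flat)  # sentinel index marking padding slots
    sorted_ids, expert_ids = [], []
    for expert in range(num_experts):
        ids = [i for i, e in enumerate(flat) if e == expert]
        # Pad this expert's token list up to a multiple of block_size.
        ids.extend([pad_id] * ((-len(ids)) % block_size))
        sorted_ids.extend(ids)
        expert_ids.extend([expert] * (len(ids) // block_size))
    return sorted_ids, expert_ids, len(sorted_ids)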


def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            A_scale: Optional[torch.Tensor],
                            B_scale: Optional[torch.Tensor],
                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool, top_k: int,
                            config: Dict[str, Any], compute_type: tl.dtype,
                            use_fp8: bool) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if not use_fp8:
        assert A_scale is None
        assert B_scale is None
    else:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None

    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        B.shape[2],
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8=use_fp8,
        **config,
    )
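
    # Note on the argument order above: B's strides are passed as
    # (stride(0), stride(2), stride(1)), so the kernel walks B[e] in
    # transposed order and effectively computes A @ B[e].T. Likewise, C's
    # strides start at stride(1) because C (shape (M, top_k, N)) is addressed
    # as rows of its (M * top_k, N) flattening via `sorted_token_ids`.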


def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
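
# For example, E=8, N=7168 and dtype="float8" would produce a name like the
# following (device name purely illustrative):
#   E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=float8.json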


@functools.lru_cache
def get_moe_configs(E: int, N: int,
                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """
    # First look up if an optimized configuration is available in the configs
    # directory
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info(
                f"Using configuration from {config_file_path} for MoE layer.")
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    return None
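
# The JSON files above are expected to map batch sizes to Triton meta-parameter
# dictionaries, e.g. (values purely illustrative):
#   {"1":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64,
#           "GROUP_SIZE_M": 1},
#    "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
#           "GROUP_SIZE_M": 8}}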


def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
      (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
      Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
      for the kernel configuration.
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
      products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
    - a1_scale (Optional[torch.Tensor]): Optional fp8 scale for the
      activations fed into the w1 matmul.
    - a2_scale (Optional[torch.Tensor]): Optional fp8 scale for the
      activations fed into the w2 matmul.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if is_hip():
        # The MoE kernels are not yet supported on ROCm.
        routing_weights = torch.softmax(gating_output,
                                        dim=-1,
                                        dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    else:
        import aphrodite._moe_C as moe_kernels

        topk_weights = torch.empty(M,
                                   topk,
                                   dtype=torch.float32,
                                   device=hidden_states.device)
        topk_ids = torch.empty(M,
                               topk,
                               dtype=torch.int32,
                               device=hidden_states.device)
        token_expert_indicies = torch.empty(M,
                                            topk,
                                            dtype=torch.int32,
                                            device=hidden_states.device)
        moe_kernels.topk_softmax(
            topk_weights,
            topk_ids,
            token_expert_indicies,
            gating_output.float(),  # TODO(woosuk): Optimize this.
        )
        del token_expert_indicies  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        configs = get_moe_configs(E, w2.shape[2],
                                  "float8" if use_fp8 else None)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32,
                'GROUP_SIZE_M': 8
            }

            if M <= E:
                config = {
                    'BLOCK_SIZE_M': 16,
                    'BLOCK_SIZE_N': 32,
                    'BLOCK_SIZE_K': 64,
                    'GROUP_SIZE_M': 1
                }
    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config['BLOCK_SIZE_M'], E)

    invoke_fused_moe_kernel(hidden_states,
                            w1,
                            intermediate_cache1,
                            a1_scale,
                            w1_scale,
                            topk_weights,
                            topk_ids,
                            sorted_token_ids,
                            expert_ids,
                            num_tokens_post_padded,
                            False,
                            topk_ids.shape[1],
                            config,
                            compute_type=tl.float16,
                            use_fp8=use_fp8)

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(intermediate_cache2,
                            w2,
                            intermediate_cache3,
                            a2_scale,
                            w2_scale,
                            topk_weights,
                            topk_ids,
                            sorted_token_ids,
                            expert_ids,
                            num_tokens_post_padded,
                            True,
                            1,
                            config,
                            compute_type=tl.float16,
                            use_fp8=use_fp8)

    if inplace:
        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                         dim=1,
                         out=hidden_states)
    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     dim=1)
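

if __name__ == "__main__":
    # Minimal usage sketch: assumes a CUDA device and the compiled aphrodite
    # extensions; shapes are illustrative. w1 is (E, 2 * intermediate, hidden)
    # because silu_and_mul splits the first matmul's output in half, and w2 is
    # (E, hidden, intermediate).
    num_tokens, hidden, intermediate, num_experts, topk = 4, 128, 256, 8, 2
    device = "cuda"
    hidden_states = torch.randn(num_tokens, hidden,
                                dtype=torch.float16, device=device)
    w1 = torch.randn(num_experts, 2 * intermediate, hidden,
                     dtype=torch.float16, device=device)
    w2 = torch.randn(num_experts, hidden, intermediate,
                     dtype=torch.float16, device=device)
    gating_output = torch.randn(num_tokens, num_experts,
                                dtype=torch.float32, device=device)
    out = fused_moe(hidden_states, w1, w2, gating_output,
                    topk=topk, renormalize=True)
    print(out.shape)  # expected: torch.Size([4, 128])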