fused_moe.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. """Fused MoE kernel."""
  2. import functools
  3. import json
  4. import os
  5. from typing import Any, Callable, Dict, Optional, Tuple
  6. import torch
  7. import triton
  8. import triton.language as tl
  9. from loguru import logger
  10. import aphrodite.common.envs as envs
  11. from aphrodite import _custom_ops as ops
  12. from aphrodite.platforms import current_platform
  13. @triton.jit
  14. def fused_moe_kernel(
  15. # Pointers to matrices
  16. a_ptr,
  17. b_ptr,
  18. c_ptr,
  19. a_scale_ptr,
  20. b_scale_ptr,
  21. topk_weights_ptr,
  22. sorted_token_ids_ptr,
  23. expert_ids_ptr,
  24. num_tokens_post_padded_ptr,
  25. # Matrix dimensions
  26. N,
  27. K,
  28. EM,
  29. num_valid_tokens,
  30. # The stride variables represent how much to increase the ptr by when
  31. # moving by 1 element in a particular dimension. E.g. `stride_am` is
  32. # how much to increase `a_ptr` by to get the element one row down
  33. # (A has M rows).
  34. stride_am,
  35. stride_ak,
  36. stride_be,
  37. stride_bk,
  38. stride_bn,
  39. stride_cm,
  40. stride_cn,
  41. stride_bse,
  42. stride_bsn,
  43. # Meta-parameters
  44. BLOCK_SIZE_M: tl.constexpr,
  45. BLOCK_SIZE_N: tl.constexpr,
  46. BLOCK_SIZE_K: tl.constexpr,
  47. GROUP_SIZE_M: tl.constexpr,
  48. MUL_ROUTED_WEIGHT: tl.constexpr,
  49. top_k: tl.constexpr,
  50. compute_type: tl.constexpr,
  51. use_fp8_w8a8: tl.constexpr,
  52. use_int8_w8a16: tl.constexpr):
  53. """
  54. Implements the fused computation for a Mixture of Experts (MOE) using
  55. token and expert matrices.
  56. Key Parameters:
  57. - A: The input tensor representing tokens with shape (*, K), where '*' can
  58. be any shape representing batches and K is the feature dimension of
  59. each token.
  60. - B: The stacked MOE weight tensor with shape (E, N, K), where E is
  61. the number of experts, K is the input feature dimension, and N is
  62. the output feature dimension.
  63. - C: The output cache tensor with shape (M, topk, N), where M is the
  64. total number of tokens post padding, topk is the number of times
  65. each token is repeated, and N is the output feature dimension.
  66. - sorted_token_ids: A tensor containing the sorted indices of tokens,
  67. repeated topk times and arranged by the expert index they are
  68. assigned to.
  69. - expert_ids: A tensor containing the indices of the expert for each
  70. block. It determines which expert matrix from B should be used for
  71. each block in A.
  72. This kernel performs the multiplication of a token by its corresponding
  73. expert matrix as determined by `expert_ids`. The sorting of
  74. `sorted_token_ids` by expert index and padding ensures divisibility by
  75. BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
  76. multiplication across different blocks processed by the same expert.
  77. """
  78. # -----------------------------------------------------------
  79. # Map program ids `pid` to the block of C it should compute.
  80. # This is done in a grouped ordering to promote L2 data reuse.
  81. pid = tl.program_id(axis=0)
  82. num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
  83. num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
  84. num_pid_in_group = GROUP_SIZE_M * num_pid_n
  85. group_id = pid // num_pid_in_group
  86. first_pid_m = group_id * GROUP_SIZE_M
  87. group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
  88. pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
  89. pid_n = (pid % num_pid_in_group) // group_size_m
  90. # ----------------------------------------------------------
  91. # Create pointers for the first blocks of A and B.
  92. # We will advance this pointer as we move in the K direction
  93. # and accumulate
  94. # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
  95. # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
  96. num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
  97. if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
  98. return
  99. offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  100. offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
  101. token_mask = offs_token < num_valid_tokens
  102. offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
  103. offs_k = tl.arange(0, BLOCK_SIZE_K)
  104. a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
  105. offs_k[None, :] * stride_ak)
  106. off_experts = tl.load(expert_ids_ptr + pid_m)
  107. b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
  108. offs_bn[None, :] * stride_bn)
  109. if use_int8_w8a16:
  110. b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
  111. None, :] * stride_bsn
  112. b_scale = tl.load(b_scale_ptrs)
  113. if use_fp8_w8a8:
  114. a_scale = tl.load(a_scale_ptr)
  115. b_scale = tl.load(b_scale_ptr + off_experts)
  116. # -----------------------------------------------------------
  117. # Iterate to compute a block of the C matrix.
  118. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
  119. # of fp32 values for higher accuracy.
  120. # `accumulator` will be converted back to fp16 after the loop.
  121. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  122. for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
  123. # Load the next block of A and B, generate a mask by checking the
  124. # K dimension.
  125. a = tl.load(a_ptrs,
  126. mask=token_mask[:, None] &
  127. (offs_k[None, :] < K - k * BLOCK_SIZE_K),
  128. other=0.0)
  129. b = tl.load(b_ptrs,
  130. mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
  131. other=0.0)
  132. # We accumulate along the K dimension.
  133. if use_int8_w8a16:
  134. accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
  135. elif use_fp8_w8a8:
  136. accumulator = tl.dot(a, b, acc=accumulator)
  137. else:
  138. accumulator += tl.dot(a, b)
  139. # Advance the ptrs to the next K block.
  140. a_ptrs += BLOCK_SIZE_K * stride_ak
  141. b_ptrs += BLOCK_SIZE_K * stride_bk
  142. if MUL_ROUTED_WEIGHT:
  143. moe_weight = tl.load(topk_weights_ptr + offs_token,
  144. mask=token_mask,
  145. other=0)
  146. accumulator = accumulator * moe_weight[:, None]
  147. if use_int8_w8a16:
  148. accumulator = (accumulator * b_scale).to(compute_type)
  149. elif use_fp8_w8a8:
  150. accumulator = (accumulator * a_scale * b_scale).to(compute_type)
  151. else:
  152. accumulator = accumulator.to(compute_type)
  153. # -----------------------------------------------------------
  154. # Write back the block of the output
  155. offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  156. c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
  157. None, :]
  158. c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
  159. tl.store(c_ptrs, accumulator, mask=c_mask)
  160. def moe_align_block_size(
  161. topk_ids: torch.Tensor, block_size: int,
  162. num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  163. """
  164. Aligns the token distribution across experts to be compatible with block
  165. size for matrix multiplication.
  166. Parameters:
  167. - topk_ids: A tensor of shape [total_tokens, top_k] representing the
  168. top-k expert indices for each token.
  169. - block_size: The block size used in block matrix multiplication.
  170. - num_experts: The total number of experts.
  171. Returns:
  172. - sorted_token_ids: A tensor containing the sorted token indices according
  173. to their allocated expert.
  174. - expert_ids: A tensor indicating the assigned expert index for each block.
  175. - num_tokens_post_padded: The total number of tokens after padding,
  176. ensuring divisibility by block_size.
  177. This function pads the number of tokens that each expert needs to process
  178. so that it is divisible by block_size.
  179. Padding ensures that during block matrix multiplication, the dimensions
  180. align correctly.
  181. Example:
  182. Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
  183. block_size = 4, and num_experts = 4:
  184. - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
  185. with each expert needing to process 3 tokens.
  186. - As block_size is 4, we pad 1 token for each expert.
  187. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
  188. - Then append padding tokens [12, 12, 12, 12] for each block.
  189. - After sorting by expert index, we obtain token_ids
  190. [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
  191. Tokens 12 are non-existent (padding) and are ignored in
  192. the subsequent matrix multiplication.
  193. - The padding ensures that the total number of tokens is now divisible
  194. by block_size for proper block matrix operations.
  195. """
  196. max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
  197. sorted_ids = torch.empty((max_num_tokens_padded, ),
  198. dtype=torch.int32,
  199. device=topk_ids.device)
  200. sorted_ids.fill_(topk_ids.numel())
  201. max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
  202. expert_ids = torch.empty((max_num_m_blocks, ),
  203. dtype=torch.int32,
  204. device=topk_ids.device)
  205. num_tokens_post_pad = torch.empty((1),
  206. dtype=torch.int32,
  207. device=topk_ids.device)
  208. ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
  209. expert_ids, num_tokens_post_pad)
  210. return sorted_ids, expert_ids, num_tokens_post_pad
  211. def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
  212. A_scale: Optional[torch.Tensor],
  213. B_scale: Optional[torch.Tensor],
  214. topk_weights: torch.Tensor, topk_ids: torch.Tensor,
  215. sorted_token_ids: torch.Tensor,
  216. expert_ids: torch.Tensor,
  217. num_tokens_post_padded: torch.Tensor,
  218. mul_routed_weight: bool, top_k: int,
  219. config: Dict[str, Any], compute_type: tl.dtype,
  220. use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
  221. assert topk_weights.stride(1) == 1
  222. assert sorted_token_ids.stride(0) == 1
  223. if use_fp8_w8a8:
  224. A, A_scale = ops.scaled_fp8_quant(A, A_scale)
  225. assert B_scale is not None
  226. elif use_int8_w8a16:
  227. assert B_scale is not None
  228. else:
  229. assert A_scale is None
  230. assert B_scale is None
  231. grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
  232. 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
  233. fused_moe_kernel[grid](
  234. A,
  235. B,
  236. C,
  237. A_scale,
  238. B_scale,
  239. topk_weights,
  240. sorted_token_ids,
  241. expert_ids,
  242. num_tokens_post_padded,
  243. B.shape[1],
  244. B.shape[2],
  245. sorted_token_ids.shape[0],
  246. topk_ids.numel(),
  247. A.stride(0),
  248. A.stride(1),
  249. B.stride(0),
  250. B.stride(2),
  251. B.stride(1),
  252. C.stride(1),
  253. C.stride(2),
  254. B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
  255. B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
  256. MUL_ROUTED_WEIGHT=mul_routed_weight,
  257. top_k=top_k,
  258. compute_type=compute_type,
  259. use_fp8_w8a8=use_fp8_w8a8,
  260. use_int8_w8a16=use_int8_w8a16,
  261. **config,
  262. )
  263. def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
  264. device_name = current_platform.get_device_name().replace(" ", "_")
  265. dtype_selector = "" if not dtype else f",dtype={dtype}"
  266. return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
  267. @functools.lru_cache
  268. def get_moe_configs(E: int, N: int,
  269. dtype: Optional[str]) -> Optional[Dict[int, Any]]:
  270. """
  271. Return optimized configurations for the fused MoE kernel.
  272. The return value will be a dictionary that maps an irregular grid of
  273. batch sizes to configurations of the fused_moe kernel. To evaluate the
  274. kernel on a given batch size bs, the closest batch size in the grid should
  275. be picked and the associated configuration chosen to invoke the kernel.
  276. """
  277. # First look up if an optimized configuration is available in the configs
  278. # directory
  279. json_file_name = get_config_file_name(E, N, dtype)
  280. config_file_path = os.path.join(
  281. os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
  282. if os.path.exists(config_file_path):
  283. with open(config_file_path) as f:
  284. logger.info(
  285. f"Using configuration from {config_file_path} for MoE layer.")
  286. # If a configuration has been found, return it
  287. return {int(key): val for key, val in json.load(f).items()}
  288. # If no optimized configuration is available, we will use the default
  289. # configuration
  290. return None
  291. def get_default_config(M: int, E: int, N: int, K: int, topk: int,
  292. dtype: Optional[str],
  293. is_marlin: bool) -> Dict[str, int]:
  294. config = {
  295. 'BLOCK_SIZE_M': 64,
  296. 'BLOCK_SIZE_N': 64,
  297. 'BLOCK_SIZE_K': 32,
  298. 'GROUP_SIZE_M': 8
  299. }
  300. if M <= E or (is_marlin and M <= 32):
  301. config = {
  302. 'BLOCK_SIZE_M': 16,
  303. 'BLOCK_SIZE_N': 32,
  304. 'BLOCK_SIZE_K': 64,
  305. 'GROUP_SIZE_M': 1
  306. }
  307. return config
  308. def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
  309. w2_shape: Tuple[int, ...],
  310. top_k: int,
  311. dtype: Optional[str],
  312. M: int,
  313. override_config: Optional[Dict[str,
  314. Any]] = None,
  315. is_marlin: bool = False):
  316. if override_config:
  317. config = override_config
  318. else:
  319. # First try to load optimal config from the file
  320. E, _, N = w2_shape
  321. configs = get_moe_configs(E, N, dtype)
  322. if configs:
  323. # If an optimal configuration map has been found, look up the
  324. # optimal config
  325. config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
  326. else:
  327. # Else use the default config
  328. config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
  329. is_marlin)
  330. return config
  331. def fused_topk(
  332. hidden_states: torch.Tensor,
  333. gating_output: torch.Tensor,
  334. topk: int,
  335. renormalize: bool,
  336. ):
  337. assert hidden_states.shape[0] == gating_output.shape[0], (
  338. "Number of tokens mismatch")
  339. M, _ = hidden_states.shape
  340. topk_weights = torch.empty(M,
  341. topk,
  342. dtype=torch.float32,
  343. device=hidden_states.device)
  344. topk_ids = torch.empty(M,
  345. topk,
  346. dtype=torch.int32,
  347. device=hidden_states.device)
  348. token_expert_indicies = torch.empty(M,
  349. topk,
  350. dtype=torch.int32,
  351. device=hidden_states.device)
  352. ops.topk_softmax(
  353. topk_weights,
  354. topk_ids,
  355. token_expert_indicies,
  356. gating_output.float(), # TODO: Optimize this.
  357. )
  358. del token_expert_indicies # Not used. Will be used in the future.
  359. if renormalize:
  360. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  361. return topk_weights, topk_ids
  362. # This is used by the Deepseek-V2 model
  363. def grouped_topk(hidden_states: torch.Tensor,
  364. gating_output: torch.Tensor,
  365. topk: int,
  366. renormalize: bool,
  367. num_expert_group: int = 0,
  368. topk_group: int = 0):
  369. assert hidden_states.shape[0] == gating_output.shape[0], (
  370. "Number of tokens mismatch")
  371. scores = torch.softmax(gating_output, dim=-1)
  372. num_token = scores.shape[0]
  373. group_scores = scores.view(num_token, num_expert_group,
  374. -1).max(dim=-1).values # [n, n_group]
  375. group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
  376. sorted=False)[1] # [n, top_k_group]
  377. group_mask = torch.zeros_like(group_scores) # [n, n_group]
  378. group_mask.scatter_(1, group_idx, 1) # [n, n_group]
  379. score_mask = group_mask.unsqueeze(-1).expand(
  380. num_token, num_expert_group,
  381. scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
  382. tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
  383. topk_weights, topk_ids = torch.topk(tmp_scores,
  384. k=topk,
  385. dim=-1,
  386. sorted=False)
  387. if renormalize:
  388. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  389. return topk_weights, topk_ids
  390. def fused_marlin_moe(hidden_states: torch.Tensor,
  391. w1: torch.Tensor,
  392. w2: torch.Tensor,
  393. gating_output: torch.Tensor,
  394. g_idx1: torch.Tensor,
  395. g_idx2: torch.Tensor,
  396. rand_perm1: torch.Tensor,
  397. rand_perm2: torch.Tensor,
  398. topk: int,
  399. custom_routing_function: Optional[Callable] = None,
  400. renormalize: bool = True,
  401. override_config: Optional[Dict[str, Any]] = None,
  402. use_fp8: bool = False,
  403. w1_scale: Optional[torch.Tensor] = None,
  404. w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
  405. """
  406. This function computes a Mixture of Experts (MoE) layer using two sets of
  407. weights, w1 and w2, and top-k gating mechanism.
  408. Parameters:
  409. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
  410. - w1 (torch.Tensor): The first set of expert weights.
  411. - w2 (torch.Tensor): The second set of expert weights.
  412. - gating_output (torch.Tensor): The output of the gating operation
  413. (before softmax).
  414. - topk (int): The number of top-k experts to select.
  415. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  416. - inplace (bool): If True, perform the operation in-place.
  417. Defaults to False.
  418. - override_config (Optional[Dict[str, Any]]): Optional override
  419. for the kernel configuration.
  420. - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
  421. products for w1 and w2. Defaults to False.
  422. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
  423. w1.
  424. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
  425. w2.
  426. Returns:
  427. - torch.Tensor: The output tensor after applying the MoE layer.
  428. """
  429. # Check constraints.
  430. assert hidden_states.shape[0] == gating_output.shape[0], (
  431. "Number of tokens mismatch")
  432. assert hidden_states.shape[
  433. 1] == w1.shape[1] * 16, "Hidden size mismatch w1"
  434. assert hidden_states.shape[
  435. 1] == w2.shape[2] // 2, "Hidden size mismatch w2"
  436. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
  437. assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
  438. assert w1.is_contiguous(), "Expert weights1 must be contiguous"
  439. assert w2.is_contiguous(), "Expert weights2 must be contiguous"
  440. assert hidden_states.dtype in [
  441. torch.float32, torch.float16, torch.bfloat16
  442. ]
  443. #TODO fp8 is not implemented yet
  444. assert not use_fp8
  445. M, K = hidden_states.shape
  446. E = w1.shape[0]
  447. N = w2.shape[1] * 16
  448. if custom_routing_function is None:
  449. topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
  450. renormalize)
  451. else:
  452. topk_weights, topk_ids = custom_routing_function(
  453. hidden_states, gating_output, topk, renormalize)
  454. get_config_func = functools.partial(try_get_optimal_moe_config,
  455. w1.shape,
  456. w2.shape,
  457. topk_ids.shape[1],
  458. "float8" if use_fp8 else None,
  459. override_config=override_config,
  460. is_marlin=True)
  461. config = get_config_func(M)
  462. block_size_m = config['BLOCK_SIZE_M']
  463. sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
  464. max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
  465. workspace = torch.zeros(max_workspace_size,
  466. dtype=torch.int,
  467. device="cuda",
  468. requires_grad=False)
  469. intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
  470. device=hidden_states.device,
  471. dtype=hidden_states.dtype)
  472. intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
  473. hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
  474. g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
  475. block_size_m, True, False)
  476. ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
  477. intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
  478. intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
  479. w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
  480. block_size_m, False, True)
  481. return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
  482. dim=1)
  483. def get_config_dtype_str(dtype: torch.dtype,
  484. use_int8_w8a16: Optional[bool] = False,
  485. use_fp8_w8a8: Optional[bool] = False):
  486. if use_fp8_w8a8:
  487. return "fp8_w8a8"
  488. elif use_int8_w8a16:
  489. return "int8_w8a16"
  490. elif dtype == torch.float:
  491. # avoiding cases where kernel fails when float32 MoE
  492. # use fp16/bfloat16 configs
  493. return "float32"
  494. return None
  495. def fused_experts(hidden_states: torch.Tensor,
  496. w1: torch.Tensor,
  497. w2: torch.Tensor,
  498. topk_weights: torch.Tensor,
  499. topk_ids: torch.Tensor,
  500. inplace: bool = False,
  501. override_config: Optional[Dict[str, Any]] = None,
  502. use_fp8_w8a8: bool = False,
  503. use_int8_w8a16: bool = False,
  504. w1_scale: Optional[torch.Tensor] = None,
  505. w2_scale: Optional[torch.Tensor] = None,
  506. a1_scale: Optional[torch.Tensor] = None,
  507. a2_scale: Optional[torch.Tensor] = None):
  508. # Check constraints.
  509. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
  510. assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
  511. assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
  512. assert w1.is_contiguous(), "Expert weights1 must be contiguous"
  513. assert w2.is_contiguous(), "Expert weights2 must be contiguous"
  514. assert hidden_states.dtype in [
  515. torch.float32, torch.float16, torch.bfloat16
  516. ]
  517. num_tokens, _ = hidden_states.shape
  518. E, N, _ = w1.shape
  519. CHUNK_SIZE = envs.APHRODITE_FUSED_MOE_CHUNK_SIZE
  520. M = min(num_tokens, CHUNK_SIZE)
  521. config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
  522. use_int8_w8a16=use_int8_w8a16,
  523. dtype=hidden_states.dtype)
  524. get_config_func = functools.partial(
  525. try_get_optimal_moe_config,
  526. w1.shape,
  527. w2.shape,
  528. topk_ids.shape[1],
  529. config_dtype,
  530. override_config=override_config,
  531. )
  532. config = get_config_func(M)
  533. intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
  534. device=hidden_states.device,
  535. dtype=hidden_states.dtype)
  536. intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
  537. device=hidden_states.device,
  538. dtype=hidden_states.dtype)
  539. intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
  540. device=hidden_states.device,
  541. dtype=hidden_states.dtype)
  542. compute_type = (tl.bfloat16
  543. if hidden_states.dtype == torch.bfloat16 else tl.float16)
  544. if inplace:
  545. out_hidden_states = hidden_states
  546. else:
  547. out_hidden_states = torch.empty_like(hidden_states)
  548. for chunk in range((num_tokens // CHUNK_SIZE) + 1):
  549. begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
  550. min((chunk + 1) * CHUNK_SIZE,
  551. num_tokens))
  552. curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
  553. tokens_in_chunk, _ = curr_hidden_states.shape
  554. if tokens_in_chunk == 0:
  555. break
  556. if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
  557. # Adjust the intermediate cache size and config for the last
  558. # chunk. Note that in most cases we only have one chunk
  559. # so the cache size and config are already set correctly and
  560. # do not need to be adjusted.
  561. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
  562. intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
  563. intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
  564. config = get_config_func(tokens_in_chunk)
  565. curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
  566. curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
  567. sorted_token_ids, expert_ids, num_tokens_post_padded = (
  568. moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
  569. invoke_fused_moe_kernel(curr_hidden_states,
  570. w1,
  571. intermediate_cache1,
  572. a1_scale,
  573. w1_scale,
  574. curr_topk_weights,
  575. curr_topk_ids,
  576. sorted_token_ids,
  577. expert_ids,
  578. num_tokens_post_padded,
  579. False,
  580. topk_ids.shape[1],
  581. config,
  582. compute_type=compute_type,
  583. use_fp8_w8a8=use_fp8_w8a8,
  584. use_int8_w8a16=use_int8_w8a16)
  585. ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
  586. invoke_fused_moe_kernel(intermediate_cache2,
  587. w2,
  588. intermediate_cache3,
  589. a2_scale,
  590. w2_scale,
  591. curr_topk_weights,
  592. curr_topk_ids,
  593. sorted_token_ids,
  594. expert_ids,
  595. num_tokens_post_padded,
  596. True,
  597. 1,
  598. config,
  599. compute_type=compute_type,
  600. use_fp8_w8a8=use_fp8_w8a8,
  601. use_int8_w8a16=use_int8_w8a16)
  602. torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
  603. dim=1,
  604. out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
  605. return out_hidden_states
  606. def fused_moe(
  607. hidden_states: torch.Tensor,
  608. w1: torch.Tensor,
  609. w2: torch.Tensor,
  610. gating_output: torch.Tensor,
  611. topk: int,
  612. renormalize: bool,
  613. inplace: bool = False,
  614. override_config: Optional[Dict[str, Any]] = None,
  615. use_grouped_topk: bool = False,
  616. num_expert_group: Optional[int] = None,
  617. topk_group: Optional[int] = None,
  618. custom_routing_function: Optional[Callable] = None,
  619. use_fp8_w8a8: bool = False,
  620. use_int8_w8a16: bool = False,
  621. w1_scale: Optional[torch.Tensor] = None,
  622. w2_scale: Optional[torch.Tensor] = None,
  623. a1_scale: Optional[torch.Tensor] = None,
  624. a2_scale: Optional[torch.Tensor] = None,
  625. ) -> torch.Tensor:
  626. """
  627. This function computes a Mixture of Experts (MoE) layer using two sets of
  628. weights, w1 and w2, and top-k gating mechanism.
  629. Parameters:
  630. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
  631. - w1 (torch.Tensor): The first set of expert weights.
  632. - w2 (torch.Tensor): The second set of expert weights.
  633. - gating_output (torch.Tensor): The output of the gating operation
  634. (before softmax).
  635. - topk (int): The number of top-k experts to select.
  636. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  637. - inplace (bool): If True, perform the operation in-place.
  638. Defaults to False.
  639. - override_config (Optional[Dict[str, Any]]): Optional override
  640. for the kernel configuration.
  641. - num_expert_group: Optional[int]: additional parameter for grouped_topk
  642. - topk_group: Optional[int]: additional parameter for grouped_topk
  643. - use_grouped_topk: If True, use grouped_topk instead of fused_topk
  644. note: Deepseekv2 model uses grouped_topk
  645. - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
  646. products for w1 and w2. Defaults to False.
  647. - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
  648. products for w1 and w2. Defaults to False.
  649. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
  650. w1.
  651. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
  652. w2.
  653. Returns:
  654. - torch.Tensor: The output tensor after applying the MoE layer.
  655. """
  656. # Check constraints.
  657. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
  658. if use_grouped_topk:
  659. assert num_expert_group is not None and topk_group is not None
  660. topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
  661. topk, renormalize,
  662. num_expert_group, topk_group)
  663. elif custom_routing_function is None:
  664. topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
  665. renormalize)
  666. else:
  667. topk_weights, topk_ids = custom_routing_function(
  668. hidden_states, gating_output, topk, renormalize)
  669. return fused_experts(hidden_states,
  670. w1,
  671. w2,
  672. topk_weights,
  673. topk_ids,
  674. inplace=inplace,
  675. override_config=override_config,
  676. use_fp8_w8a8=use_fp8_w8a8,
  677. use_int8_w8a16=use_int8_w8a16,
  678. w1_scale=w1_scale,
  679. w2_scale=w2_scale,
  680. a1_scale=a1_scale,
  681. a2_scale=a2_scale)