fused_moe.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. """Fused MoE kernel."""
  2. import functools
  3. import json
  4. import os
  5. from typing import Any, Dict, Optional, Tuple
  6. import torch
  7. import triton
  8. import triton.language as tl
  9. from loguru import logger
  10. from aphrodite import _custom_ops as ops
  11. APHRODITE_FUSED_MOE_CHUNK_SIZE = int(
  12. os.getenv("APHRODITE_FUSED_MOE_CHUNK_SIZE", "65536"))
@triton.jit
def fused_moe_kernel(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        a_scale_ptr,
        b_scale_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N,
        K,
        EM,  # length of sorted_token_ids (padded row count)
        num_valid_tokens,  # ids >= this value are padding and are masked out
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        use_fp8: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor. Row `t // top_k` of A is read for each entry `t`
        of `sorted_token_ids`, so A holds one row per token (or one row per
        token-expert pair when the caller passes top_k=1). K is the feature
        dimension.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output tensor. Row `t` (scaled by `stride_cm`) is written for
        each valid entry `t` of `sorted_token_ids`, i.e. rows are indexed by
        the flattened token*top_k position. N is the output feature
        dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to. Entries >= num_valid_tokens are padding.
    - expert_ids: A tensor containing the indices of the expert for each
        block of BLOCK_SIZE_M rows. It determines which expert matrix from B
        should be used for each block in A.
    - a_scale_ptr / b_scale_ptr: only read when use_fp8; a single scalar
        activation scale and one scale per expert, respectively.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # The grid is sized for the worst-case padding; blocks past the actual
    # padded length have no work.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding slots hold a sentinel >= num_valid_tokens.
    token_mask = offs_token < num_valid_tokens

    # `% N` wraps out-of-range columns so B loads stay in bounds; the
    # store below masks those columns out with `offs_cn < N`.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `// top_k` maps a flattened (token, expert-slot) index back to its row
    # in A.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    # Every row in this M-block is routed to the same expert.
    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)

    if use_fp8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to `compute_type` after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Router weights are read at the flattened token index; this assumes
        # topk_weights is laid out contiguously as (tokens, top_k) — the host
        # side asserts stride(1) == 1.
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
        # De-quantize with the per-tensor activation scale and per-expert
        # weight scale before down-casting.
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
  150. def moe_align_block_size(
  151. topk_ids: torch.Tensor, block_size: int,
  152. num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  153. """
  154. Aligns the token distribution across experts to be compatible with block
  155. size for matrix multiplication.
  156. Parameters:
  157. - topk_ids: A tensor of shape [total_tokens, top_k] representing the
  158. top-k expert indices for each token.
  159. - block_size: The block size used in block matrix multiplication.
  160. - num_experts: The total number of experts.
  161. Returns:
  162. - sorted_token_ids: A tensor containing the sorted token indices according
  163. to their allocated expert.
  164. - expert_ids: A tensor indicating the assigned expert index for each block.
  165. - num_tokens_post_padded: The total number of tokens after padding,
  166. ensuring divisibility by block_size.
  167. This function pads the number of tokens that each expert needs to process
  168. so that it is divisible by block_size.
  169. Padding ensures that during block matrix multiplication, the dimensions
  170. align correctly.
  171. Example:
  172. Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
  173. block_size = 4, and num_experts = 4:
  174. - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
  175. with each expert needing to process 3 tokens.
  176. - As block_size is 4, we pad 1 token for each expert.
  177. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
  178. - Then append padding tokens [12, 12, 12, 12] for each block.
  179. - After sorting by expert index, we obtain token_ids
  180. [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
  181. Tokens 12 are non-existent (padding) and are ignored in
  182. the subsequent matrix multiplication.
  183. - The padding ensures that the total number of tokens is now divisible
  184. by block_size for proper block matrix operations.
  185. """
  186. max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
  187. sorted_ids = torch.empty((max_num_tokens_padded, ),
  188. dtype=torch.int32,
  189. device=topk_ids.device)
  190. sorted_ids.fill_(topk_ids.numel())
  191. max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
  192. expert_ids = torch.empty((max_num_m_blocks, ),
  193. dtype=torch.int32,
  194. device=topk_ids.device)
  195. num_tokens_post_pad = torch.empty((1),
  196. dtype=torch.int32,
  197. device=topk_ids.device)
  198. ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
  199. expert_ids, num_tokens_post_pad)
  200. return sorted_ids, expert_ids, num_tokens_post_pad
  201. def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
  202. A_scale: Optional[torch.Tensor],
  203. B_scale: Optional[torch.Tensor],
  204. topk_weights: torch.Tensor, topk_ids: torch.Tensor,
  205. sorted_token_ids: torch.Tensor,
  206. expert_ids: torch.Tensor,
  207. num_tokens_post_padded: torch.Tensor,
  208. mul_routed_weight: bool, top_k: int,
  209. config: Dict[str, Any], compute_type: tl.dtype,
  210. use_fp8: bool) -> None:
  211. assert topk_weights.stride(1) == 1
  212. assert sorted_token_ids.stride(0) == 1
  213. if not use_fp8:
  214. assert A_scale is None
  215. assert B_scale is None
  216. else:
  217. A, A_scale = ops.scaled_fp8_quant(A, A_scale)
  218. assert B_scale is not None
  219. grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
  220. 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
  221. fused_moe_kernel[grid](
  222. A,
  223. B,
  224. C,
  225. A_scale,
  226. B_scale,
  227. topk_weights,
  228. sorted_token_ids,
  229. expert_ids,
  230. num_tokens_post_padded,
  231. B.shape[1],
  232. B.shape[2],
  233. sorted_token_ids.shape[0],
  234. topk_ids.numel(),
  235. A.stride(0),
  236. A.stride(1),
  237. B.stride(0),
  238. B.stride(2),
  239. B.stride(1),
  240. C.stride(1),
  241. C.stride(2),
  242. MUL_ROUTED_WEIGHT=mul_routed_weight,
  243. top_k=top_k,
  244. compute_type=compute_type,
  245. use_fp8=use_fp8,
  246. **config,
  247. )
  248. def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
  249. device_name = torch.cuda.get_device_name().replace(" ", "_")
  250. dtype_selector = "" if not dtype else f",dtype={dtype}"
  251. return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
  252. @functools.lru_cache
  253. def get_moe_configs(E: int, N: int,
  254. dtype: Optional[str]) -> Optional[Dict[int, Any]]:
  255. """
  256. Return optimized configurations for the fused MoE kernel.
  257. The return value will be a dictionary that maps an irregular grid of
  258. batch sizes to configurations of the fused_moe kernel. To evaluate the
  259. kernel on a given batch size bs, the closest batch size in the grid should
  260. be picked and the associated configuration chosen to invoke the kernel.
  261. """
  262. # First look up if an optimized configuration is available in the configs
  263. # directory
  264. json_file_name = get_config_file_name(E, N, dtype)
  265. config_file_path = os.path.join(
  266. os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
  267. if os.path.exists(config_file_path):
  268. with open(config_file_path) as f:
  269. logger.info(f"Using configuration from {config_file_path} "
  270. "for MoE layer.")
  271. # If a configuration has been found, return it
  272. return {int(key): val for key, val in json.load(f).items()}
  273. # If no optimized configuration is available, we will use the default
  274. # configuration
  275. return None
  276. def get_default_config(
  277. M: int,
  278. E: int,
  279. N: int,
  280. K: int,
  281. topk: int,
  282. dtype: Optional[str],
  283. ) -> Dict[str, int]:
  284. config = {
  285. 'BLOCK_SIZE_M': 64,
  286. 'BLOCK_SIZE_N': 64,
  287. 'BLOCK_SIZE_K': 32,
  288. 'GROUP_SIZE_M': 8
  289. }
  290. if M <= E:
  291. config = {
  292. 'BLOCK_SIZE_M': 16,
  293. 'BLOCK_SIZE_N': 32,
  294. 'BLOCK_SIZE_K': 64,
  295. 'GROUP_SIZE_M': 1
  296. }
  297. return config
  298. def try_get_optimal_moe_config(
  299. w1_shape: Tuple[int, ...],
  300. w2_shape: Tuple[int, ...],
  301. top_k: int,
  302. dtype: Optional[str],
  303. M: int,
  304. override_config: Optional[Dict[str, Any]] = None,
  305. ):
  306. if override_config:
  307. config = override_config
  308. else:
  309. # First try to load optimal config from the file
  310. E, _, N = w2_shape
  311. configs = get_moe_configs(E, N, dtype)
  312. if configs:
  313. # If an optimal configuration map has been found, look up the
  314. # optimal config
  315. config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
  316. else:
  317. # Else use the default config
  318. config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
  319. return config
  320. def fused_topk(
  321. hidden_states: torch.Tensor,
  322. gating_output: torch.Tensor,
  323. topk: int,
  324. renormalize: bool,
  325. ):
  326. assert hidden_states.shape[0] == gating_output.shape[0], (
  327. "Number of tokens mismatch")
  328. M, _ = hidden_states.shape
  329. topk_weights = torch.empty(M,
  330. topk,
  331. dtype=torch.float32,
  332. device=hidden_states.device)
  333. topk_ids = torch.empty(M,
  334. topk,
  335. dtype=torch.int32,
  336. device=hidden_states.device)
  337. token_expert_indicies = torch.empty(M,
  338. topk,
  339. dtype=torch.int32,
  340. device=hidden_states.device)
  341. ops.topk_softmax(
  342. topk_weights,
  343. topk_ids,
  344. token_expert_indicies,
  345. gating_output.float(), # TODO: Optimize this.
  346. )
  347. del token_expert_indicies # Not used. Will be used in the future.
  348. if renormalize:
  349. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  350. return topk_weights, topk_ids
  351. # This is used by the Deepseek-V2 model
  352. def grouped_topk(hidden_states: torch.Tensor,
  353. gating_output: torch.Tensor,
  354. topk: int,
  355. renormalize: bool,
  356. num_expert_group: int = 0,
  357. topk_group: int = 0):
  358. assert hidden_states.shape[0] == gating_output.shape[0], (
  359. "Number of tokens mismatch")
  360. scores = torch.softmax(gating_output, dim=-1)
  361. num_token = scores.shape[0]
  362. group_scores = scores.view(num_token, num_expert_group,
  363. -1).max(dim=-1).values # [n, n_group]
  364. group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
  365. sorted=False)[1] # [n, top_k_group]
  366. group_mask = torch.zeros_like(group_scores) # [n, n_group]
  367. group_mask.scatter_(1, group_idx, 1) # [n, n_group]
  368. score_mask = group_mask.unsqueeze(-1).expand(
  369. num_token, num_expert_group,
  370. scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
  371. tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
  372. topk_weights, topk_ids = torch.topk(tmp_scores,
  373. k=topk,
  374. dim=-1,
  375. sorted=False)
  376. if renormalize:
  377. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  378. return topk_weights, topk_ids
  379. def fused_experts(hidden_states: torch.Tensor,
  380. w1: torch.Tensor,
  381. w2: torch.Tensor,
  382. topk_weights: torch.Tensor,
  383. topk_ids: torch.Tensor,
  384. inplace: bool = False,
  385. override_config: Optional[Dict[str, Any]] = None,
  386. use_fp8: bool = False,
  387. w1_scale: Optional[torch.Tensor] = None,
  388. w2_scale: Optional[torch.Tensor] = None,
  389. a1_scale: Optional[torch.Tensor] = None,
  390. a2_scale: Optional[torch.Tensor] = None):
  391. # Check constraints.
  392. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
  393. assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
  394. assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
  395. assert w1.is_contiguous(), "Expert weights1 must be contiguous"
  396. assert w2.is_contiguous(), "Expert weights2 must be contiguous"
  397. assert hidden_states.dtype in [
  398. torch.float32, torch.float16, torch.bfloat16
  399. ]
  400. num_tokens, _ = hidden_states.shape
  401. E, N, _ = w1.shape
  402. # We execute the fused_moe kernel in chunks.
  403. CHUNK_SIZE = APHRODITE_FUSED_MOE_CHUNK_SIZE
  404. M = min(num_tokens, CHUNK_SIZE)
  405. get_config_func = functools.partial(
  406. try_get_optimal_moe_config,
  407. w1.shape,
  408. w2.shape,
  409. topk_ids.shape[1],
  410. "float8" if use_fp8 else None,
  411. override_config=override_config,
  412. )
  413. config = get_config_func(M)
  414. intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
  415. device=hidden_states.device,
  416. dtype=hidden_states.dtype)
  417. intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
  418. device=hidden_states.device,
  419. dtype=hidden_states.dtype)
  420. intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
  421. device=hidden_states.device,
  422. dtype=hidden_states.dtype)
  423. compute_type = (tl.bfloat16
  424. if hidden_states.dtype == torch.bfloat16 else tl.float16)
  425. if inplace:
  426. out_hidden_states = hidden_states
  427. else:
  428. out_hidden_states = torch.empty_like(hidden_states)
  429. for chunk in range((num_tokens // CHUNK_SIZE) + 1):
  430. begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
  431. min((chunk + 1) * CHUNK_SIZE,
  432. num_tokens))
  433. curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
  434. tokens_in_chunk, _ = curr_hidden_states.shape
  435. if tokens_in_chunk == 0:
  436. break
  437. if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
  438. # Adjust the intermediate cache size and config for the last
  439. # chunk. Note that in most cases we only have one chunk
  440. # so the cache size and config are already set correctly and
  441. # do not need to be adjusted.
  442. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
  443. intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
  444. intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
  445. config = get_config_func(tokens_in_chunk)
  446. curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
  447. curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
  448. sorted_token_ids, expert_ids, num_tokens_post_padded = (
  449. moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
  450. invoke_fused_moe_kernel(curr_hidden_states,
  451. w1,
  452. intermediate_cache1,
  453. a1_scale,
  454. w1_scale,
  455. curr_topk_weights,
  456. curr_topk_ids,
  457. sorted_token_ids,
  458. expert_ids,
  459. num_tokens_post_padded,
  460. False,
  461. topk_ids.shape[1],
  462. config,
  463. compute_type=compute_type,
  464. use_fp8=use_fp8)
  465. ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
  466. invoke_fused_moe_kernel(intermediate_cache2,
  467. w2,
  468. intermediate_cache3,
  469. a2_scale,
  470. w2_scale,
  471. curr_topk_weights,
  472. curr_topk_ids,
  473. sorted_token_ids,
  474. expert_ids,
  475. num_tokens_post_padded,
  476. True,
  477. 1,
  478. config,
  479. compute_type=compute_type,
  480. use_fp8=use_fp8)
  481. torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
  482. dim=1,
  483. out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
  484. return out_hidden_states
  485. def fused_moe(
  486. hidden_states: torch.Tensor,
  487. w1: torch.Tensor,
  488. w2: torch.Tensor,
  489. gating_output: torch.Tensor,
  490. topk: int,
  491. renormalize: bool,
  492. inplace: bool = False,
  493. override_config: Optional[Dict[str, Any]] = None,
  494. use_grouped_topk: bool = False,
  495. num_expert_group: Optional[int] = None,
  496. topk_group: Optional[int] = None,
  497. use_fp8: bool = False,
  498. w1_scale: Optional[torch.Tensor] = None,
  499. w2_scale: Optional[torch.Tensor] = None,
  500. a1_scale: Optional[torch.Tensor] = None,
  501. a2_scale: Optional[torch.Tensor] = None,
  502. ) -> torch.Tensor:
  503. """
  504. This function computes a Mixture of Experts (MoE) layer using two sets of
  505. weights, w1 and w2, and top-k gating mechanism.
  506. Parameters:
  507. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
  508. - w1 (torch.Tensor): The first set of expert weights.
  509. - w2 (torch.Tensor): The second set of expert weights.
  510. - gating_output (torch.Tensor): The output of the gating operation
  511. (before softmax).
  512. - topk (int): The number of top-k experts to select.
  513. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  514. - inplace (bool): If True, perform the operation in-place.
  515. Defaults to False.
  516. - override_config (Optional[Dict[str, Any]]): Optional override
  517. for the kernel configuration.
  518. - num_expert_group: Optional[int]: additional parameter for grouped_topk
  519. - topk_group: Optional[int]: additional parameter for grouped_topk
  520. - use_grouped_topk: If True, use grouped_topk instead of fused_topk
  521. note: Deepseekv2 model uses grouped_topk
  522. - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
  523. products for w1 and w2. Defaults to False.
  524. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
  525. w1.
  526. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
  527. w2.
  528. Returns:
  529. - torch.Tensor: The output tensor after applying the MoE layer.
  530. """
  531. # Check constraints.
  532. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
  533. if use_grouped_topk:
  534. assert num_expert_group is not None and topk_group is not None
  535. topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
  536. topk, renormalize,
  537. num_expert_group, topk_group)
  538. else:
  539. topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
  540. renormalize)
  541. return fused_experts(hidden_states,
  542. w1,
  543. w2,
  544. topk_weights,
  545. topk_ids,
  546. inplace=inplace,
  547. override_config=override_config,
  548. use_fp8=use_fp8,
  549. w1_scale=w1_scale,
  550. w2_scale=w2_scale,
  551. a1_scale=a1_scale,
  552. a2_scale=a2_scale)