fused_moe.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. """Fused MoE kernel."""
  2. import functools
  3. import json
  4. import os
  5. from typing import Any, Dict, Optional, Tuple
  6. import torch
  7. import triton
  8. import triton.language as tl
  9. from loguru import logger
  10. import aphrodite.common.envs as envs
  11. from aphrodite import _custom_ops as ops
  12. from aphrodite.platforms import current_platform
# Maximum number of tokens `fused_experts` pushes through the two expert
# GEMMs per iteration (it also bounds the intermediate cache sizes there).
# Read once from `aphrodite.common.envs` at import time.
APHRODITE_FUSED_MOE_CHUNK_SIZE = envs.APHRODITE_FUSED_MOE_CHUNK_SIZE
@triton.jit
def fused_moe_kernel(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        a_scale_ptr,
        b_scale_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N,
        K,
        EM,
        num_valid_tokens,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        # Strides of the (expert, output-channel) weight scale; only read on
        # the int8_w8a16 path (the launcher passes 0 otherwise).
        stride_bse,
        stride_bsn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        use_fp8_w8a8: tl.constexpr,
        use_int8_w8a16: tl.constexpr):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input activations with K columns. For flattened token slot id
        `t`, row `t // top_k` of A is read, so A holds one row per token
        (not per slot).
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output tensor, written as rows of length N addressed by the
        flattened token slot id (`offs_token * stride_cm`). In
        `fused_experts` this is an (M, topk, N) cache where M is the number
        of real (unpadded) tokens in the chunk. Padding slots
        (id >= num_valid_tokens) are masked out and never written.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    - num_tokens_post_padded_ptr: Pointer to a single value. Programs whose
        M-block starts at or beyond it return immediately.
    - topk_weights_ptr: Routing weights indexed by flattened slot id. They
        are only read (and multiplied into the result) when
        MUL_ROUTED_WEIGHT is set.
    - a_scale_ptr / b_scale_ptr: For use_fp8_w8a8, a single activation scale
        and one weight scale per expert. For use_int8_w8a16, a weight scale
        per (expert, output channel); a_scale_ptr is unused.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # The grid is sized for the worst-case padding (EM); blocks past the
    # actual padded length have no work.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding slots carry an id >= num_valid_tokens; mask them everywhere.
    token_mask = offs_token < num_valid_tokens

    # Wrap column indices into [0, N) so B loads never need an N mask; the
    # out-of-range columns are discarded by `c_mask` at store time.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `offs_token // top_k` maps a flattened (token, k) slot back to its
    # source token row in A.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    # Every row of this M-block belongs to the same expert.
    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    if use_int8_w8a16:
        # Per-output-channel weight scale for this expert.
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
            None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        # One activation scale and one weight scale per expert.
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted to `compute_type` after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            # int8 weights are upcast to the activation compute type; the
            # scale is applied once after the loop.
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]
    # Dequantize (if applicable) and cast to the output compute type.
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
  161. def moe_align_block_size(
  162. topk_ids: torch.Tensor, block_size: int,
  163. num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  164. """
  165. Aligns the token distribution across experts to be compatible with block
  166. size for matrix multiplication.
  167. Parameters:
  168. - topk_ids: A tensor of shape [total_tokens, top_k] representing the
  169. top-k expert indices for each token.
  170. - block_size: The block size used in block matrix multiplication.
  171. - num_experts: The total number of experts.
  172. Returns:
  173. - sorted_token_ids: A tensor containing the sorted token indices according
  174. to their allocated expert.
  175. - expert_ids: A tensor indicating the assigned expert index for each block.
  176. - num_tokens_post_padded: The total number of tokens after padding,
  177. ensuring divisibility by block_size.
  178. This function pads the number of tokens that each expert needs to process
  179. so that it is divisible by block_size.
  180. Padding ensures that during block matrix multiplication, the dimensions
  181. align correctly.
  182. Example:
  183. Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
  184. block_size = 4, and num_experts = 4:
  185. - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
  186. with each expert needing to process 3 tokens.
  187. - As block_size is 4, we pad 1 token for each expert.
  188. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
  189. - Then append padding tokens [12, 12, 12, 12] for each block.
  190. - After sorting by expert index, we obtain token_ids
  191. [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
  192. Tokens 12 are non-existent (padding) and are ignored in
  193. the subsequent matrix multiplication.
  194. - The padding ensures that the total number of tokens is now divisible
  195. by block_size for proper block matrix operations.
  196. """
  197. max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
  198. sorted_ids = torch.empty((max_num_tokens_padded, ),
  199. dtype=torch.int32,
  200. device=topk_ids.device)
  201. sorted_ids.fill_(topk_ids.numel())
  202. max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
  203. expert_ids = torch.empty((max_num_m_blocks, ),
  204. dtype=torch.int32,
  205. device=topk_ids.device)
  206. num_tokens_post_pad = torch.empty((1),
  207. dtype=torch.int32,
  208. device=topk_ids.device)
  209. ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
  210. expert_ids, num_tokens_post_pad)
  211. return sorted_ids, expert_ids, num_tokens_post_pad
  212. def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
  213. A_scale: Optional[torch.Tensor],
  214. B_scale: Optional[torch.Tensor],
  215. topk_weights: torch.Tensor, topk_ids: torch.Tensor,
  216. sorted_token_ids: torch.Tensor,
  217. expert_ids: torch.Tensor,
  218. num_tokens_post_padded: torch.Tensor,
  219. mul_routed_weight: bool, top_k: int,
  220. config: Dict[str, Any], compute_type: tl.dtype,
  221. use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
  222. assert topk_weights.stride(1) == 1
  223. assert sorted_token_ids.stride(0) == 1
  224. if use_fp8_w8a8:
  225. A, A_scale = ops.scaled_fp8_quant(A, A_scale)
  226. assert B_scale is not None
  227. elif use_int8_w8a16:
  228. assert B_scale is not None
  229. else:
  230. assert A_scale is None
  231. assert B_scale is None
  232. grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
  233. 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
  234. fused_moe_kernel[grid](
  235. A,
  236. B,
  237. C,
  238. A_scale,
  239. B_scale,
  240. topk_weights,
  241. sorted_token_ids,
  242. expert_ids,
  243. num_tokens_post_padded,
  244. B.shape[1],
  245. B.shape[2],
  246. sorted_token_ids.shape[0],
  247. topk_ids.numel(),
  248. A.stride(0),
  249. A.stride(1),
  250. B.stride(0),
  251. B.stride(2),
  252. B.stride(1),
  253. C.stride(1),
  254. C.stride(2),
  255. B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
  256. B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
  257. MUL_ROUTED_WEIGHT=mul_routed_weight,
  258. top_k=top_k,
  259. compute_type=compute_type,
  260. use_fp8_w8a8=use_fp8_w8a8,
  261. use_int8_w8a16=use_int8_w8a16,
  262. **config,
  263. )
  264. def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
  265. device_name = current_platform.get_device_name().replace(" ", "_")
  266. dtype_selector = "" if not dtype else f",dtype={dtype}"
  267. return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
  268. @functools.lru_cache
  269. def get_moe_configs(E: int, N: int,
  270. dtype: Optional[str]) -> Optional[Dict[int, Any]]:
  271. """
  272. Return optimized configurations for the fused MoE kernel.
  273. The return value will be a dictionary that maps an irregular grid of
  274. batch sizes to configurations of the fused_moe kernel. To evaluate the
  275. kernel on a given batch size bs, the closest batch size in the grid should
  276. be picked and the associated configuration chosen to invoke the kernel.
  277. """
  278. # First look up if an optimized configuration is available in the configs
  279. # directory
  280. json_file_name = get_config_file_name(E, N, dtype)
  281. config_file_path = os.path.join(
  282. os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
  283. if os.path.exists(config_file_path):
  284. with open(config_file_path) as f:
  285. logger.info(f"Using configuration from {config_file_path} "
  286. "for MoE layer.")
  287. # If a configuration has been found, return it
  288. return {int(key): val for key, val in json.load(f).items()}
  289. # If no optimized configuration is available, we will use the default
  290. # configuration
  291. return None
  292. def get_default_config(
  293. M: int,
  294. E: int,
  295. N: int,
  296. K: int,
  297. topk: int,
  298. dtype: Optional[str],
  299. ) -> Dict[str, int]:
  300. config = {
  301. 'BLOCK_SIZE_M': 64,
  302. 'BLOCK_SIZE_N': 64,
  303. 'BLOCK_SIZE_K': 32,
  304. 'GROUP_SIZE_M': 8
  305. }
  306. if M <= E:
  307. config = {
  308. 'BLOCK_SIZE_M': 16,
  309. 'BLOCK_SIZE_N': 32,
  310. 'BLOCK_SIZE_K': 64,
  311. 'GROUP_SIZE_M': 1
  312. }
  313. return config
  314. def try_get_optimal_moe_config(
  315. w1_shape: Tuple[int, ...],
  316. w2_shape: Tuple[int, ...],
  317. top_k: int,
  318. dtype: Optional[str],
  319. M: int,
  320. override_config: Optional[Dict[str, Any]] = None,
  321. ):
  322. if override_config:
  323. config = override_config
  324. else:
  325. # First try to load optimal config from the file
  326. E, _, N = w2_shape
  327. configs = get_moe_configs(E, N, dtype)
  328. if configs:
  329. # If an optimal configuration map has been found, look up the
  330. # optimal config
  331. config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
  332. else:
  333. # Else use the default config
  334. config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
  335. return config
  336. def fused_topk(
  337. hidden_states: torch.Tensor,
  338. gating_output: torch.Tensor,
  339. topk: int,
  340. renormalize: bool,
  341. ):
  342. assert hidden_states.shape[0] == gating_output.shape[0], (
  343. "Number of tokens mismatch")
  344. M, _ = hidden_states.shape
  345. topk_weights = torch.empty(M,
  346. topk,
  347. dtype=torch.float32,
  348. device=hidden_states.device)
  349. topk_ids = torch.empty(M,
  350. topk,
  351. dtype=torch.int32,
  352. device=hidden_states.device)
  353. token_expert_indicies = torch.empty(M,
  354. topk,
  355. dtype=torch.int32,
  356. device=hidden_states.device)
  357. ops.topk_softmax(
  358. topk_weights,
  359. topk_ids,
  360. token_expert_indicies,
  361. gating_output.float(), # TODO: Optimize this.
  362. )
  363. del token_expert_indicies # Not used. Will be used in the future.
  364. if renormalize:
  365. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  366. return topk_weights, topk_ids
  367. # This is used by the Deepseek-V2 model
  368. def grouped_topk(hidden_states: torch.Tensor,
  369. gating_output: torch.Tensor,
  370. topk: int,
  371. renormalize: bool,
  372. num_expert_group: int = 0,
  373. topk_group: int = 0):
  374. assert hidden_states.shape[0] == gating_output.shape[0], (
  375. "Number of tokens mismatch")
  376. scores = torch.softmax(gating_output, dim=-1)
  377. num_token = scores.shape[0]
  378. group_scores = scores.view(num_token, num_expert_group,
  379. -1).max(dim=-1).values # [n, n_group]
  380. group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
  381. sorted=False)[1] # [n, top_k_group]
  382. group_mask = torch.zeros_like(group_scores) # [n, n_group]
  383. group_mask.scatter_(1, group_idx, 1) # [n, n_group]
  384. score_mask = group_mask.unsqueeze(-1).expand(
  385. num_token, num_expert_group,
  386. scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
  387. tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
  388. topk_weights, topk_ids = torch.topk(tmp_scores,
  389. k=topk,
  390. dim=-1,
  391. sorted=False)
  392. if renormalize:
  393. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  394. return topk_weights, topk_ids
  395. def get_config_dtype_str(dtype: torch.dtype,
  396. use_int8_w8a16: Optional[bool] = False,
  397. use_fp8_w8a8: Optional[bool] = False):
  398. if use_fp8_w8a8:
  399. return "fp8_w8a8"
  400. elif use_int8_w8a16:
  401. return "int8_w8a16"
  402. elif dtype == torch.float:
  403. # avoiding cases where kernel fails when float32 MoE
  404. # use fp16/bfloat16 configs
  405. return "float32"
  406. return None
  407. def fused_experts(hidden_states: torch.Tensor,
  408. w1: torch.Tensor,
  409. w2: torch.Tensor,
  410. topk_weights: torch.Tensor,
  411. topk_ids: torch.Tensor,
  412. inplace: bool = False,
  413. override_config: Optional[Dict[str, Any]] = None,
  414. use_fp8_w8a8: bool = False,
  415. use_int8_w8a16: bool = False,
  416. w1_scale: Optional[torch.Tensor] = None,
  417. w2_scale: Optional[torch.Tensor] = None,
  418. a1_scale: Optional[torch.Tensor] = None,
  419. a2_scale: Optional[torch.Tensor] = None):
  420. # Check constraints.
  421. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
  422. assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
  423. assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
  424. assert w1.is_contiguous(), "Expert weights1 must be contiguous"
  425. assert w2.is_contiguous(), "Expert weights2 must be contiguous"
  426. assert hidden_states.dtype in [
  427. torch.float32, torch.float16, torch.bfloat16
  428. ]
  429. num_tokens, _ = hidden_states.shape
  430. E, N, _ = w1.shape
  431. # We execute the fused_moe kernel in chunks.
  432. CHUNK_SIZE = APHRODITE_FUSED_MOE_CHUNK_SIZE
  433. M = min(num_tokens, CHUNK_SIZE)
  434. config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
  435. use_int8_w8a16=use_int8_w8a16,
  436. dtype=hidden_states.dtype)
  437. get_config_func = functools.partial(
  438. try_get_optimal_moe_config,
  439. w1.shape,
  440. w2.shape,
  441. topk_ids.shape[1],
  442. config_dtype,
  443. override_config=override_config,
  444. )
  445. config = get_config_func(M)
  446. intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
  447. device=hidden_states.device,
  448. dtype=hidden_states.dtype)
  449. intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
  450. device=hidden_states.device,
  451. dtype=hidden_states.dtype)
  452. intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
  453. device=hidden_states.device,
  454. dtype=hidden_states.dtype)
  455. compute_type = (tl.bfloat16
  456. if hidden_states.dtype == torch.bfloat16 else tl.float16)
  457. if inplace:
  458. out_hidden_states = hidden_states
  459. else:
  460. out_hidden_states = torch.empty_like(hidden_states)
  461. for chunk in range((num_tokens // CHUNK_SIZE) + 1):
  462. begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
  463. min((chunk + 1) * CHUNK_SIZE,
  464. num_tokens))
  465. curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
  466. tokens_in_chunk, _ = curr_hidden_states.shape
  467. if tokens_in_chunk == 0:
  468. break
  469. if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
  470. # Adjust the intermediate cache size and config for the last
  471. # chunk. Note that in most cases we only have one chunk
  472. # so the cache size and config are already set correctly and
  473. # do not need to be adjusted.
  474. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
  475. intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
  476. intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
  477. config = get_config_func(tokens_in_chunk)
  478. curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
  479. curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
  480. sorted_token_ids, expert_ids, num_tokens_post_padded = (
  481. moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
  482. invoke_fused_moe_kernel(curr_hidden_states,
  483. w1,
  484. intermediate_cache1,
  485. a1_scale,
  486. w1_scale,
  487. curr_topk_weights,
  488. curr_topk_ids,
  489. sorted_token_ids,
  490. expert_ids,
  491. num_tokens_post_padded,
  492. False,
  493. topk_ids.shape[1],
  494. config,
  495. compute_type=compute_type,
  496. use_fp8_w8a8=use_fp8_w8a8,
  497. use_int8_w8a16=use_int8_w8a16)
  498. ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
  499. invoke_fused_moe_kernel(intermediate_cache2,
  500. w2,
  501. intermediate_cache3,
  502. a2_scale,
  503. w2_scale,
  504. curr_topk_weights,
  505. curr_topk_ids,
  506. sorted_token_ids,
  507. expert_ids,
  508. num_tokens_post_padded,
  509. True,
  510. 1,
  511. config,
  512. compute_type=compute_type,
  513. use_fp8_w8a8=use_fp8_w8a8,
  514. use_int8_w8a16=use_int8_w8a16)
  515. torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
  516. dim=1,
  517. out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
  518. return out_hidden_states
  519. def fused_moe(
  520. hidden_states: torch.Tensor,
  521. w1: torch.Tensor,
  522. w2: torch.Tensor,
  523. gating_output: torch.Tensor,
  524. topk: int,
  525. renormalize: bool,
  526. inplace: bool = False,
  527. override_config: Optional[Dict[str, Any]] = None,
  528. use_grouped_topk: bool = False,
  529. num_expert_group: Optional[int] = None,
  530. topk_group: Optional[int] = None,
  531. use_fp8_w8a8: bool = False,
  532. use_int8_w8a16: bool = False,
  533. w1_scale: Optional[torch.Tensor] = None,
  534. w2_scale: Optional[torch.Tensor] = None,
  535. a1_scale: Optional[torch.Tensor] = None,
  536. a2_scale: Optional[torch.Tensor] = None,
  537. ) -> torch.Tensor:
  538. """
  539. This function computes a Mixture of Experts (MoE) layer using two sets of
  540. weights, w1 and w2, and top-k gating mechanism.
  541. Parameters:
  542. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
  543. - w1 (torch.Tensor): The first set of expert weights.
  544. - w2 (torch.Tensor): The second set of expert weights.
  545. - gating_output (torch.Tensor): The output of the gating operation
  546. (before softmax).
  547. - topk (int): The number of top-k experts to select.
  548. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
  549. - inplace (bool): If True, perform the operation in-place.
  550. Defaults to False.
  551. - override_config (Optional[Dict[str, Any]]): Optional override
  552. for the kernel configuration.
  553. - num_expert_group: Optional[int]: additional parameter for grouped_topk
  554. - topk_group: Optional[int]: additional parameter for grouped_topk
  555. - use_grouped_topk: If True, use grouped_topk instead of fused_topk
  556. note: Deepseekv2 model uses grouped_topk
  557. - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
  558. products for w1 and w2. Defaults to False.
  559. - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
  560. products for w1 and w2. Defaults to False.
  561. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
  562. w1.
  563. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
  564. w2.
  565. Returns:
  566. - torch.Tensor: The output tensor after applying the MoE layer.
  567. """
  568. # Check constraints.
  569. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
  570. if use_grouped_topk:
  571. assert num_expert_group is not None and topk_group is not None
  572. topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
  573. topk, renormalize,
  574. num_expert_group, topk_group)
  575. else:
  576. topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
  577. renormalize)
  578. return fused_experts(hidden_states,
  579. w1,
  580. w2,
  581. topk_weights,
  582. topk_ids,
  583. inplace=inplace,
  584. override_config=override_config,
  585. use_fp8_w8a8=use_fp8_w8a8,
  586. use_int8_w8a16=use_int8_w8a16,
  587. w1_scale=w1_scale,
  588. w2_scale=w2_scale,
  589. a1_scale=a1_scale,
  590. a2_scale=a2_scale)