moe.py

from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F
import triton
import triton.language as tl

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import ReplicatedLinear
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_reduce)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.utils import set_weight_attrs


class MoE(nn.Module):
    """A tensor-parallel MoE layer that shards each expert across all ranks.

    Each expert's weights are sharded across all ranks. The forward pass
    first expands and groups the hidden states by expert, then computes the
    per-rank MLP output of each expert using a grouped GEMM, and finally
    reduces the output across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // tp_size

        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     linear_method=None)
        self.w1s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda"))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.intermediate_size,
                        self.hidden_size,
                        device="cuda"))
        self.w3s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda"))

        set_weight_attrs(self.w1s, {
            "weight_loader": self.weight_loader,
            "tp_type": "column"
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
            "tp_type": "row"
        })
        set_weight_attrs(self.w3s, {
            "weight_loader": self.weight_loader,
            "tp_type": "column"
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        loaded_weight = loaded_weight.t()
        # The parallel dimension is 1 for column-parallel, and 0 for
        # row-parallel.
        parallel_dim = 1 if getattr(param, "tp_type", None) == "column" else 0
        param_data = param.data
        shard_size = param_data.shape[parallel_dim + 1]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(parallel_dim, start_idx,
                                             shard_size)
        assert param_data[expert_id].shape == loaded_weight.shape
        param_data[expert_id].copy_(loaded_weight)
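        # For example (illustrative numbers only): with tp_size = 2 and a
        # full intermediate_size of 8, each expert's w1/w3 shard has shape
        # [hidden_size, 4]; rank 0 copies columns 0:4 of the transposed
        # checkpoint weight and rank 1 copies columns 4:8. For w2
        # (row-parallel) the same split is applied along dimension 0 instead.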

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Step 1: expand and permute hidden states and routing weights to
        # group hidden states by experts.
        (expanded_hidden_states, experts_range, expanded_weights,
         reverse_indices) = self.expand_and_permutate_hidden_states(
             hidden_states, selected_experts, routing_weights)
        # Step 2: compute the output of each expert.
        expanded_hidden_states = self.apply_experts_ffn(
            expanded_hidden_states, experts_range, self.w1s.data,
            self.w2s.data, self.w3s.data)
        # Step 3: apply weights to the output of each expert, and reduce
        # across ranks.
        expanded_hidden_states.mul_(expanded_weights.unsqueeze(-1))
        tensor_model_parallel_all_reduce(expanded_hidden_states)
        # Step 4: merge the output of each expert, according to the indices.
        return self.merge_expert_outputs(expanded_hidden_states,
                                         reverse_indices).view(
                                             batch_size, sequence_length,
                                             hidden_size)

    def expand_and_permutate_hidden_states(
        self,
        hidden_states: torch.Tensor,
        selected_experts: torch.Tensor,
        routing_weights: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Expand and group the hidden states and routing weights according
        to the selected experts.

        Args:
            hidden_states (torch.Tensor): [batch_size, hidden_size]
                hidden states.
            selected_experts (torch.Tensor): [batch_size, top_k_experts]
                the indices of the selected experts.
            routing_weights (torch.Tensor): [batch_size, top_k_experts]
                the routing weights of the selected experts.

        Returns:
            expanded_hidden_states: [batch_size * top_k_experts, hidden_size]
                expanded hidden states whose rows are grouped by expert.
            cum_experts_range: [num_experts + 1] the cumulative range of the
                experts in the first dimension of expanded_hidden_states.
            expanded_weights: [batch_size * top_k_experts] the expanded
                expert weights for each row in expanded_hidden_states.
            reverse_indices: [batch_size * top_k_experts] the index of the
                original hidden state (token) that each row in
                expanded_hidden_states maps back to.
        """
        reverse_indices = torch.argsort(selected_experts.view(-1), dim=-1)
        cum_experts_range = torch.zeros(self.num_total_experts + 1,
                                        dtype=torch.int32,
                                        device=hidden_states.device)
        num_rows_per_expert = torch.zeros(self.num_total_experts,
                                          dtype=torch.int32,
                                          device=hidden_states.device)
        ops.bincount(selected_experts.view(-1), num_rows_per_expert)
        torch.cumsum(num_rows_per_expert, dim=0, out=cum_experts_range[1:])
        expanded_weights = routing_weights.view(-1)[reverse_indices]
        reverse_indices.div_(self.top_k, rounding_mode="floor")
        return (hidden_states[reverse_indices], cum_experts_range,
                expanded_weights, reverse_indices)

    def apply_experts_ffn(
        self,
        # [batch_size * top_k_experts, hidden_size]
        expanded_hidden_states: torch.Tensor,
        cum_experts_range: torch.Tensor,  # [num_experts + 1]
        w1s: torch.Tensor,  # [num_experts, hidden_size, ffn_dim]
        w2s: torch.Tensor,  # [num_experts, ffn_dim, hidden_size]
        w3s: torch.Tensor,  # [num_experts, hidden_size, ffn_dim]
    ) -> torch.Tensor:  # [batch_size * top_k_experts, hidden_size]
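        # Per expert e, for the rows x_e routed to it, this computes a
        # SiLU-gated MLP: (silu(x_e @ w1s[e]) * (x_e @ w3s[e])) @ w2s[e],
        # with all experts evaluated together via grouped GEMMs.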
        grouped_w1_out = grouped_matmul(expanded_hidden_states,
                                        cum_experts_range, w1s, "silu")
        grouped_w3_out = grouped_matmul(expanded_hidden_states,
                                        cum_experts_range, w3s)
        grouped_w1_out.mul_(grouped_w3_out)
        return grouped_matmul(grouped_w1_out, cum_experts_range, w2s)

    def merge_expert_outputs(
        self,
        # [batch_size * top_k_experts, hidden_size]
        expanded_hidden_states: torch.Tensor,
        reverse_indices: torch.Tensor,  # [batch_size * top_k_experts]
    ) -> torch.Tensor:
        out = torch.zeros(expanded_hidden_states.shape[0] // self.top_k,
                          self.hidden_size,
                          device=expanded_hidden_states.device,
                          dtype=expanded_hidden_states.dtype)
        out.index_add_(0, reverse_indices, expanded_hidden_states)
        return out
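

# Illustrative usage of the MoE layer (a sketch, not executed here): it
# assumes an initialized tensor-model-parallel group, CUDA tensors, and
# expert weights already populated through `weight_loader`. The sizes are
# Mixtral-like and serve only as an example.
#
#   moe = MoE(num_experts=8, top_k=2, hidden_size=4096,
#             intermediate_size=14336)
#   hidden_states = torch.randn(1, 128, 4096, device="cuda")
#   out = moe(hidden_states)  # -> [1, 128, 4096]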


# The following code is adapted from
# https://github.com/openai/triton/blob/main/python/tutorials/11-grouped-gemm.py
@triton.jit
def grouped_matmul_kernel(
    # [batch_size, k], where each group is stored compactly in the batch
    # dimension. The range of each group is specified in cumulative_m_range.
    group_a_ptr,
    # [num_groups, k, n]
    group_b_ptr,
    # [batch_size, n], where each group is stored compactly in the batch
    # dimension. The range of each group is specified in cumulative_m_range.
    group_c_ptr,
    # number of gemm problems
    group_size,
    # for each gemm problem with size <m, n, k>, m is given by
    # cumulative_m_range[i + 1] - cumulative_m_range[i].
    # n and k are the same for all problems.
    cumulative_m_range,
    n,
    k,
    # group_a_ptr.stride(0)
    stride_a0,
    # group_b_ptr.stride(1)
    stride_b1,
    # group_c_ptr.stride(0)
    stride_c0,
    # number of virtual SMs
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
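    # Persistent-kernel scheduling: NUM_SM programs are launched, and each
    # one walks over the tiles of every gemm problem, handling the tiles
    # whose global index is congruent to its program id modulo NUM_SM.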
    for g in range(group_size):
        # get the gemm size of the current problem
        a_offset = tl.load(cumulative_m_range + g)
        gm = tl.load(cumulative_m_range + g + 1) - a_offset
        gn = n
        gk = k
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        # pylint: disable=chained-comparison
        while (tile_idx >= last_problem_end
               and tile_idx < last_problem_end + num_tiles):
            # pick up a tile from the current gemm problem
            k = gk
            a_ptr = group_a_ptr + a_offset * stride_a0
            b_ptr = group_b_ptr + g * k * n
            c_ptr = group_c_ptr + a_offset * stride_c0
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles
            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * stride_a0 + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * stride_b1 + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
                                   dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                a = tl.load(a_ptrs,
                            mask=(offs_k[None, :] < k - kk * BLOCK_SIZE_K) &
                            (offs_am[:, None] < gm),
                            other=0.0)
                b = tl.load(b_ptrs,
                            mask=(offs_k[:, None] < k - kk * BLOCK_SIZE_K) &
                            (offs_bn[None, :] < gn),
                            other=0.0)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * stride_b1
            if ACTIVATION == "silu":
                accumulator = silu(accumulator)
            c = accumulator.to(group_c_ptr.dtype.element_ty)
            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + offs_cn[None, :]
            c_mask = (offs_cm[:, None] < gm) & (offs_cn[None, :] < gn)
            tl.store(c_ptrs, c, mask=c_mask)
            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM
        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


@triton.jit
def silu(x):
    return x * tl.sigmoid(x)


def grouped_matmul(
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        cumulative_group_range: torch.Tensor,
        group_b_ptr: torch.Tensor,
        activation: str = ""):
    """Performs a grouped matrix-matrix product of the matrices stored in
    input and group_b_ptr.

    input is a tensor of shape [batch_size, k] where the groups are stored
    compactly in the batch dimension. The range of each group is specified
    in cumulative_group_range. This allows the input to have a fixed shape
    regardless of the group sizes.

    Args:
        input (torch.Tensor): [batch_size, k] compact input.
        cumulative_group_range (torch.Tensor): [num_groups + 1] the cumulative
            range of the groups in input.
        group_b_ptr (torch.Tensor): [num_groups, k, n] the second matrix.
        activation (str, optional): "" or "silu". Defaults to "".

    Returns:
        torch.Tensor: [batch_size, n] compact output where the groups are
            stored compactly in the batch dimension.
    """
    device = torch.device("cuda")
    assert cumulative_group_range.shape[0] == group_b_ptr.shape[0] + 1
    group_size = cumulative_group_range.shape[0] - 1
    output = torch.zeros(input.shape[0],
                         group_b_ptr.shape[2],
                         device=device,
                         dtype=input.dtype)

    BLOCK_SIZE_M = 16
    BLOCK_SIZE_N = 64
    BLOCK_SIZE_K = 32
    num_warps = 2
    NUM_SM = 128
    num_stages = 5
    # hand-tune the block sizes for different problem sizes.
    if input.shape[0] >= 8:
        num_warps = 4
        BLOCK_SIZE_N = 128
    if input.shape[0] >= 32:
        num_warps = 4
        BLOCK_SIZE_M = 32
        BLOCK_SIZE_N = 128
    # we use a fixed number of CTAs, and it's auto-tunable
    grid = lambda META: (META["NUM_SM"], )
    grouped_matmul_kernel[grid](group_a_ptr=input,
                                group_b_ptr=group_b_ptr,
                                group_c_ptr=output,
                                group_size=group_size,
                                cumulative_m_range=cumulative_group_range,
                                n=group_b_ptr.shape[2],
                                k=group_b_ptr.shape[1],
                                stride_a0=input.stride(0),
                                stride_b1=group_b_ptr.stride(1),
                                stride_c0=output.stride(0),
                                ACTIVATION=activation,
                                BLOCK_SIZE_M=BLOCK_SIZE_M,
                                BLOCK_SIZE_N=BLOCK_SIZE_N,
                                BLOCK_SIZE_K=BLOCK_SIZE_K,
                                NUM_SM=NUM_SM,
                                num_warps=num_warps,
                                num_stages=num_stages)
    return output
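

# A minimal stand-alone sketch of how grouped_matmul is meant to be called
# (commented out; it assumes a CUDA device and half-precision inputs, and is
# not part of this module's public API):
#
#   x = torch.randn(6, 16, device="cuda", dtype=torch.float16)
#   w = torch.randn(3, 16, 32, device="cuda", dtype=torch.float16)
#   # rows 0:2 use w[0], rows 2:4 use w[1], rows 4:6 use w[2]
#   ranges = torch.tensor([0, 2, 4, 6], dtype=torch.int32, device="cuda")
#   y = grouped_matmul(x, ranges, w)  # y.shape == (6, 32)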