  1. """Tests for the MOE layers.
  2. Run `pytest tests/kernels/test_moe.py`.
  3. """
from typing import List

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from aphrodite.common.utils import seed_everything
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe, single_marlin_moe)
from aphrodite.modeling.layers.fused_moe.fused_moe import fused_topk
from aphrodite.modeling.models.mixtral import MixtralMoE
from aphrodite.quantization.utils.marlin_utils_test import marlin_quantize
from aphrodite.scalar_type import scalar_types
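
# Reference implementation: route every token to its top-k experts with a
# plain Python loop (softmax -> top-k -> per-expert GEMMs with SiluAndMul
# gating -> weighted sum over the selected experts), so the fused kernels
# below can be checked against it.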
def torch_moe(a, w1, w2, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
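
# Same reference loop for a single weight matrix. The routing weights are
# deliberately not applied: the unweighted sum is the baseline that the
# single_marlin_moe test below compares against.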
def torch_moe_single(a, w, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    _, topk_ids = torch.topk(score, topk)
    topk_ids = topk_ids.view(-1)
    for i in range(w.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = a[mask] @ w[i].transpose(0, 1)
    return (out.view(B, -1, w.shape[1])).sum(dim=1)
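
# Check the Triton fused_moe kernel against the torch reference across a
# grid of token counts (m), intermediate sizes (n), hidden sizes (k),
# expert counts (e), and top-k values.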
  45. @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
  46. @pytest.mark.parametrize("n", [2048, 256, 1024])
  47. @pytest.mark.parametrize("k", [128, 511, 1024])
  48. @pytest.mark.parametrize("e", [8, 64])
  49. @pytest.mark.parametrize("topk", [2, 6])
  50. @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
  51. def test_fused_moe(
  52. m: int,
  53. n: int,
  54. k: int,
  55. e: int,
  56. topk: int,
  57. dtype: torch.dtype,
  58. ):
  59. a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
  60. w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
  61. w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
  62. score = torch.randn((m, e), device="cuda", dtype=dtype)
  63. triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
  64. torch_output = torch_moe(a, w1, w2, score, topk)
  65. torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)

@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    aphrodite_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    ).cuda()

    # Load the weights
    aphrodite_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        aphrodite_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
        aphrodite_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    # Aphrodite uses a flattened 1D query of shape [num_tokens, hidden_dim]
    aphrodite_inputs = hf_inputs.flatten(0, 1)

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(hf_inputs)
    aphrodite_states = aphrodite_moe.forward(aphrodite_inputs)

    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    torch.testing.assert_close(hf_states.flatten(0, 1),
                               aphrodite_states,
                               rtol=mixtral_moe_tol[dtype],
                               atol=mixtral_moe_tol[dtype])
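
# Helpers for the Marlin tests. Note that compute_max_diff actually returns
# the mean absolute error normalized by the mean magnitude of the reference
# output, i.e. a relative error, not a true maximum.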
def stack_and_dev(tensors: List[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))
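
# Compare the GPTQ-Marlin MoE kernel against the Triton fused_moe kernel run
# on the dequantized reference weights, so both sides see the same
# quantization error.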
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
    num_bits: int,
):
    seed_everything(7)

    if topk > e:
        return

    # act_order only makes sense with a finite group size smaller than the
    # reduction dimension (k for w1, n for w2).
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return

    quant_type = (scalar_types.uint4b8
                  if num_bits == 4 else scalar_types.uint8b128)
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
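    # Marlin-quantize each expert's w1 with a random activation permutation,
    # keeping both the packed weights and the dequantized reference.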
    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    g_idx1_l = []
    sort_indices1_l = []

    for i in range(w1.shape[0]):
        test_perm = torch.randperm(k)
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        w_ref1_l.append(w_ref1)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l)
    sort_indices1 = stack_and_dev(sort_indices1_l)
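    # Repeat for the down-projection weights w2 (reduction dimension n).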
    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n)
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        w_ref2_l.append(w_ref2)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l)
    sort_indices2 = stack_and_dev(sort_indices2_l)
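    # Run the same routing through both kernels: the Triton path consumes the
    # dequantized reference weights, the Marlin path the packed ones.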
    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    triton_output = fused_moe(
        a,
        w_ref1.transpose(1, 2).contiguous(),
        w_ref2.transpose(1, 2).contiguous(),
        score,
        topk,
        renormalize=False,
    )
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        score,
        g_idx1,
        g_idx2,
        sort_indices1,
        sort_indices2,
        topk_weights,
        topk_ids,
        w1_scale=scales1,
        w2_scale=scales2,
        num_bits=num_bits,
    )

    assert compute_max_diff(marlin_output, triton_output) < 4e-2

@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
    num_bits: int,
):
    if topk > e:
        return

    # act_order only makes sense with a finite group size smaller than the
    # reduction dimension k.
    if act_order:
        if group_size == -1:
            return
        if group_size == k:
            return

    quant_type = (scalar_types.uint4b8
                  if num_bits == 4 else scalar_types.uint8b128)
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
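    # Marlin-quantize each expert's weight, keeping the dequantized reference
    # for the torch baseline.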
    w_ref_l = []
    qweights_l = []
    scales_l = []
    g_idx_l = []
    sort_indices_l = []

    for i in range(w.shape[0]):
        test_perm = torch.randperm(k)
        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
        w_ref_l.append(w_ref)
        qweights_l.append(qweight)
        scales_l.append(scales)
        g_idx_l.append(g_idx)
        sort_indices_l.append(sort_indices)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweights_l).contiguous()
    scales = stack_and_dev(scales_l)
    g_idx = stack_and_dev(g_idx_l)
    sort_indices = stack_and_dev(sort_indices_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    marlin_output = single_marlin_moe(a,
                                      qweight,
                                      scales,
                                      score,
                                      g_idx,
                                      sort_indices,
                                      topk,
                                      renormalize=False,
                                      num_bits=num_bits)
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2