  1. """Tests for the MOE layers.
  2. Run `pytest tests/kernels/test_moe.py`.
  3. """
  4. import pytest
  5. import torch
  6. from transformers import MixtralConfig
  7. from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
  8. from aphrodite.modeling.layers.activation import SiluAndMul
  9. from aphrodite.modeling.layers.fused_moe import fused_moe
  10. from aphrodite.modeling.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, score, topk):
    """Reference MoE implementation in plain PyTorch.

    a:     [B, D] input tokens
    w1:    [e, 2n, D] fused gate/up projection weights per expert
    w2:    [e, D, n] down projection weights per expert
    score: [B, e] router logits
    """
    B, D = a.shape
    # Duplicate each token topk times so every token/expert pair gets a row.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Apply each expert to the rows routed to it.
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    # Weighted sum of the topk expert outputs per token.
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
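

# NOTE: torch_moe applies the raw top-k softmax weights without renormalizing
# them, which is why test_fused_moe below passes renormalize=False to the
# fused kernel. Renormalization would presumably correspond to something like
#   topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
# before the weighted sum.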


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
    score = torch.randn((m, e), device='cuda', dtype=dtype)
    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
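

# Minimal standalone sketch of calling the fused kernel directly, using the
# same positional call pattern as test_fused_moe above. The shapes and dtype
# are arbitrary small values chosen for illustration, and a CUDA device is
# assumed; the leading underscore keeps pytest from collecting it as a test.
def _example_fused_moe_call():
    a = torch.randn((33, 128), device='cuda', dtype=torch.float16) / 10
    w1 = torch.randn((8, 2 * 256, 128), device='cuda', dtype=torch.float16) / 10
    w2 = torch.randn((8, 128, 256), device='cuda', dtype=torch.float16) / 10
    score = torch.randn((33, 8), device='cuda', dtype=torch.float16)
    # Route each of the 33 tokens to its top-2 experts; output is [33, 128].
    return fused_moe(a, w1, w2, score, 2, renormalize=False)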


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    aphrodite_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    ).cuda()

    # Load the weights: the HF gate (w1) and up (w3) projections are
    # concatenated into the fused w13 weight; w2 is the down projection.
    aphrodite_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        aphrodite_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
        aphrodite_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate an input batch of dimensions [batch_size, seq_len, hidden_dim]
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    # Aphrodite's MoE expects a flattened [num_tokens, hidden_dim] input
    aphrodite_inputs = hf_inputs.flatten(0, 1)

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(hf_inputs)
    aphrodite_states = aphrodite_moe.forward(aphrodite_inputs)

    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }
    torch.testing.assert_close(hf_states.flatten(0, 1),
                               aphrodite_states,
                               rtol=mixtral_moe_tol[dtype],
                               atol=mixtral_moe_tol[dtype])
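

# To run only one of the tests, select it by name, e.g.:
#   pytest tests/kernels/test_moe.py -k test_fused_moe
#   pytest tests/kernels/test_moe.py -k test_mixtral_moe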