- """Tests for the MOE layers.
- Run `pytest tests/kernels/test_moe.py`.
- """
import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, score, topk):
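    """Reference MoE forward pass in plain PyTorch.

    Routes each token to its top-k experts, applies each expert's
    SiluAndMul MLP, and sums the expert outputs weighted by the softmax
    router scores. Serves as the ground truth for the fused Triton kernel.
    """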
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
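    # Apply each expert's SiluAndMul MLP to the tokens routed to it;
    # w1 holds the fused gate/up projection, w2 the down projection.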
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
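    # Weight each expert output by its router score and sum over the top-k.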
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
- @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
- @pytest.mark.parametrize("n", [2048, 256, 1024])
- @pytest.mark.parametrize("k", [128, 511, 1024])
- @pytest.mark.parametrize("e", [8, 64])
- @pytest.mark.parametrize("topk", [2, 6])
- @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
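    # m tokens with hidden size k, e experts with intermediate size n;
    # w1 is the fused gate/up projection, w2 the down projection.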
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

    score = torch.randn((m, e), device='cuda', dtype=dtype)
    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
- @pytest.mark.parametrize("dtype",
- [torch.float32, torch.float16, torch.bfloat16])
- @torch.inference_mode()
- def test_mixtral_moe(dtype: torch.dtype):
- """Make sure our Mixtral MoE implementation agrees with the one from
- huggingface."""
- # Instantiate our and huggingface's MoE blocks
- config = MixtralConfig()
- hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    aphrodite_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    ).cuda()

    # Load the weights
    aphrodite_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
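        # HF stores the gate (w1) and up (w3) projections separately;
        # concatenate them into the fused w13 weight.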
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        aphrodite_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
        aphrodite_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    # Aphrodite expects a flattened [num_tokens, hidden_dim] input
    aphrodite_inputs = hf_inputs.flatten(0, 1)

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(hf_inputs)
    aphrodite_states = aphrodite_moe.forward(aphrodite_inputs)

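    # Per-dtype tolerances for comparing the two implementations.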
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    torch.testing.assert_close(hf_states.flatten(0, 1),
                               aphrodite_states,
                               rtol=mixtral_moe_tol[dtype],
                               atol=mixtral_moe_tol[dtype])