"""Tests for the MoE layers.

Run `pytest tests/kernels/test_moe.py`.
"""
import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, score, topk):
    """Reference (unfused) MoE forward pass used to check ``fused_moe``.

    Args:
        a: Token activations of shape (B, D).
        w1: Stacked per-expert first-layer weights; expert ``i`` computes
            ``SiluAndMul(a @ w1[i].T)``. ``test_fused_moe`` builds this as
            (E, 2 * N, D).
        w2: Stacked per-expert second-layer weights; expert ``i`` then
            applies ``h @ w2[i].T``. ``test_fused_moe`` builds this as
            (E, D, N).
        score: Router logits of shape (B, E).
        topk: Number of experts each token is routed to.

    Returns:
        Tensor of shape (B, w2.shape[1]) in ``a.dtype``. It is the sum, over
        each token's top-k experts, of the expert output weighted by that
        expert's softmax routing probability. The top-k probabilities are
        NOT renormalized, which matches ``fused_moe(..., renormalize=False)``.
    """
    B, D = a.shape
    # Duplicate every token once per selected expert: row ``b * topk + j``
    # of the result holds token ``b`` and pairs with its j-th expert choice
    # (the same ordering ``topk_ids.view(-1)`` produces below).
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    # Routing probabilities are computed in float32 regardless of input dtype.
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    act = SiluAndMul()
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        # Skip experts that no token was routed to.
        if mask.any():
            hidden = act(a[mask] @ w1[i].transpose(0, 1))
            out[mask] = hidden @ w2[i].transpose(0, 1)
    # Weight each (token, expert) output by its routing probability and sum
    # over the top-k experts of each token.
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    """Check the Triton ``fused_moe`` kernel against ``torch_moe``.

    ``m`` tokens of width ``k`` are routed to ``topk`` of ``e`` experts,
    each with intermediate size ``n``.
    """
    # Fixed seed so a failing parametrization can be reproduced exactly.
    torch.manual_seed(0)
    # Inputs are scaled down so products stay small in half precision.
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
    score = torch.randn((m, e), device='cuda', dtype=dtype)
    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""
    # Fixed seed so both the HF weights and the inputs are reproducible.
    torch.manual_seed(0)

    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    aphrodite_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    ).cuda()

    # Load the weights. Aphrodite packs each expert's HF w1 and w3 weights
    # into a single w13_weight tensor, concatenated along dim 0 (w1 first).
    aphrodite_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        aphrodite_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
        aphrodite_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    # Aphrodite's MixtralMoE takes the tokens flattened to a 2D tensor of
    # shape [num_tokens, hidden_dim].
    aphrodite_inputs = hf_inputs.flatten(0, 1)

    # Run forward passes for both MoE blocks. The HF block returns a
    # (hidden_states, router_logits) pair; only the states are compared.
    hf_states, _ = hf_moe(hf_inputs)
    aphrodite_states = aphrodite_moe(aphrodite_inputs)

    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    torch.testing.assert_close(hf_states.flatten(0, 1),
                               aphrodite_states,
                               rtol=mixtral_moe_tol[dtype],
                               atol=mixtral_moe_tol[dtype])