@@ -27,15 +27,12 @@ import torch
 from torch import nn
 from transformers import MixtralConfig
 
-from aphrodite import _custom_ops as ops
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import print_warning_once
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   tensor_model_parallel_all_reduce)
-from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
@@ -46,12 +43,10 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.interfaces import SupportsLoRA
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.quantization.fp8 import (Fp8Config, per_tensor_dequantize,
-                                        per_tensor_quantize)
+
+from .interfaces import SupportsLoRA
 
 
 class MixtralMoE(nn.Module):
@@ -63,240 +58,55 @@ class MixtralMoE(nn.Module):
     across ranks.
     """
 
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 tp_size: Optional[int] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
+        self.gate = ReplicatedLinear(hidden_size,
+                                     num_experts,
                                      bias=False,
-                                     params_dtype=self.params_dtype,
+                                     params_dtype=params_dtype,
                                      quant_config=None)
 
-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                   2 * self.intermediate_size,
-                                                   self.hidden_size,
-                                                   dtype=params_dtype),
-                                       requires_grad=False)
-        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                  self.hidden_size,
-                                                  self.intermediate_size,
-                                                  dtype=params_dtype),
-                                      requires_grad=False)
-
-        set_weight_attrs(self.w13_weight, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2_weight, {
-            "weight_loader": self.weight_loader,
-        })
-
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They will be combined to a single scale after weight loading.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     2,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                    dtype=torch.float32),
-                                         requires_grad=False)
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            #   process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(self.w13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.w2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.ones(
-                    self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-
-                set_weight_attrs(self.a13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.a2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-
-        # Loading scales
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "act_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(self.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(self.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
-
-            # Re-initialize w13_scale because we directly quantize
-            # merged w13 weights and generate a single scaling factor.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w2_weight.data[expert, :, :])
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        else:
-            # If checkpoint is fp8 + static, cleanup act_scales.
-            #   Since state_dict has an input_scale per expert but our kernels
-            #   are passed one input_scale shared across all experts.
-            if self.quant_config.activation_scheme == "static":
-                if self.a13_scale is None or self.a2_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None.")
-
-                if (not all_close_1d(self.a13_scale)
-                        or not all_close_1d(self.a2_scale)):
-                    print_warning_once(
-                        "Found act_scales that are not equal for "
-                        "fp8 MoE layer. Using the maximum across experts "
-                        "for each layer. ")
-
-                self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                             requires_grad=False)
-
-            assert self.w13_scale is not None
-            shard_size = self.intermediate_size
-            max_w13_scales = self.w13_scale.max(dim=1).values
-            for expert_id in range(self.num_total_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        self.w13_weight[expert_id][start:start +
-                                                   shard_size, :],
-                        self.w13_scale[expert_id][shard_id])
-                    self.w13_weight[expert_id][
-                        start:start + shard_size, :] = per_tensor_quantize(
-                            dq_weight, max_w13_scales[expert_id])
-                    start += shard_size
-
-            self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
+        self.experts = FusedMoE(num_experts=num_experts,
+                                top_k=top_k,
+                                hidden_size=hidden_size,
+                                intermediate_size=intermediate_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=True,
+                                quant_config=quant_config,
+                                tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
+        final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(num_tokens, hidden_size)
 
 
 class MixtralAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -501,8 +311,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
         self.lora_config = lora_config
+
         self.model = MixtralModel(config,
                                   cache_config,
                                   quant_config,
@@ -559,25 +371,28 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
 
         expert_params_mapping = [
             # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in ["w1", "w3"] else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            ("experts.w13_weight"
+             if weight_name in ["w1", "w3"] else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            ("experts.a13_scale"
+             if weight_name in ["w1", "w3"] else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ]
 
         params_dict = dict(self.named_parameters())
@@ -597,7 +412,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -606,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else:
@@ -630,8 +447,3 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
-
-
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
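
For intuition, the eager-mode sketch below approximates what the fused expert path above computes for Mixtral: top-k softmax routing with renormalization (renormalize=True), a SiLU gate/up projection taken from the merged w13 layout (w1 rows first, then w3, as in the removed weight_loader), and a w2 down projection. It is a rough reference under those assumptions only; it is not the aphrodite FusedMoE kernel, the helper name moe_reference is hypothetical, and tensor-parallel sharding, fp8 scales, and the all-reduce implied by reduce_results=True are omitted.

import torch
import torch.nn.functional as F


def moe_reference(hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
                  w13: torch.Tensor,            # [num_experts, 2 * intermediate, hidden]
                  w2: torch.Tensor,             # [num_experts, hidden, intermediate]
                  router_logits: torch.Tensor,  # [num_tokens, num_experts]
                  top_k: int) -> torch.Tensor:
    # Top-k routing weights, renormalized to sum to 1 over the selected experts.
    routing_weights = F.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = routing_weights.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(hidden_states.dtype)

    out = torch.zeros_like(hidden_states)
    intermediate = w2.shape[-1]
    for expert_id in range(w13.shape[0]):
        # Tokens (and their top-k slot) routed to this expert.
        token_ids, slot = (topk_ids == expert_id).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        x = hidden_states[token_ids]
        gate_up = x @ w13[expert_id].t()          # first half = w1(x), second half = w3(x)
        act = F.silu(gate_up[:, :intermediate]) * gate_up[:, intermediate:]
        expert_out = act @ w2[expert_id].t()      # down projection
        # Accumulate each expert's contribution weighted by its routing weight.
        out.index_add_(0, token_ids,
                       expert_out * topk_weights[token_ids, slot].unsqueeze(-1))
    return out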