# coding=utf-8
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs.dbrx import DbrxConfig


class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.layer(hidden_states)
        return router_logits


class DbrxExperts(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    is used for the forward pass, and the outputs are reduced across ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        self.d_model = config.d_model
        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.d_model,
                device="cuda",
                dtype=self.params_dtype,
            ))
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.d_model,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            ))

        set_weight_attrs(
            self.ws,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2s,
            {
                "weight_loader": self.weight_loader,
            },
        )

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # DBRX uses GLU for each expert.
        # GLU has 3 linear layers: w1, v1 and w2.
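        # w1 and v1 are stacked into the fused ``ws`` parameter along dim 1
        # ([w1; v1]); w2 is stored transposed in ``w2s``. Each rank keeps only
        # its ``shard`` slice of the intermediate dimension.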
        if weight_name.endswith("w1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
        if weight_name.endswith("v1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, shard_size:2 * shard_size, :] = (
                loaded_weight[:, shard, :])
        if weight_name.endswith("w2"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            ).transpose(1, 2)
            param_data[:] = loaded_weight[:, :, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)


class DbrxAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
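            # For example, with 8 KV heads and tp_world_size == 16, each KV
            # head is replicated on two ranks and every rank ends up with a
            # single KV head.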
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states


class DbrxFusedNormAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, cache_config, quant_config)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual


class DbrxBlock(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                     quant_config)
        self.ffn = DbrxExperts(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class DbrxModel(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList([
            DbrxBlock(config, cache_config, quant_config)
            for _ in range(config.n_layers)
        ])
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
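                # The DBRX checkpoint ships no bias tensors, so dropping the
                # parameter (rather than zeroing it) keeps the state dict
                # aligned with the checkpoint at load time.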
                module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class DbrxForCausalLM(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Checkpoint expert weights w1/v1 map onto the fused "ws" parameter
        # and w2 onto "w2s"; all other weights fall through to the default
        # loader.
        expert_params_mapping = [(
            "ws" if weight_name in ["w1", "v1"] else "w2s",
            f"experts.mlp.{weight_name}",
        ) for weight_name in ["w1", "v1", "w2"]]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
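

# Example usage (a minimal sketch, not part of the model definition): assuming
# aphrodite exposes the vLLM-style ``LLM`` and ``SamplingParams`` entrypoints
# and the DBRX checkpoint is reachable, loading and querying the model looks
# roughly like:
#
#   from aphrodite import LLM, SamplingParams
#
#   llm = LLM(model="databricks/dbrx-instruct", tensor_parallel_size=8)
#   outputs = llm.generate(["What is DBRX?"], SamplingParams(max_tokens=64))
#   print(outputs[0].outputs[0].text)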