from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata
from aphrodite.common.sequence import PoolerOutput
from aphrodite.modeling.layers.pooler import Pooler, PoolingType
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.pooling_metadata import PoolingMetadata


class LlamaEmbeddingModel(nn.Module):
    """A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Checkpoint shard names (e.g. q_proj/k_proj/v_proj) are mapped onto
        # the fused parameters used by this model (qkv_proj, gate_up_proj).
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load the weight directly, using
                # the parameter's own loader if it defines one.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
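

# Illustrative sketch only, not part of the model: a minimal stand-in for what
# Pooler(pooling_type=PoolingType.LAST, normalize=True) conceptually produces.
# The real Pooler consumes PoolingMetadata and runs inside the engine; here the
# flattened `hidden_states` tensor and the per-prompt lengths are synthetic
# values assumed purely for demonstration.
if __name__ == "__main__":
    # Two prompts of lengths 3 and 4, flattened into [num_tokens, hidden_size].
    hidden_states = torch.randn(7, 4)
    prompt_lens = torch.tensor([3, 4])

    # LAST pooling: pick the hidden state of each prompt's final token.
    last_token_idx = torch.cumsum(prompt_lens, dim=0) - 1
    embeddings = hidden_states[last_token_idx]  # [num_prompts, hidden_size]

    # normalize=True: L2-normalize each embedding vector.
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
    print(embeddings.shape)  # torch.Size([2, 4])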