from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata
from aphrodite.common.sequence import PoolerOutput
from aphrodite.modeling.layers.pooler import Pooler, PoolingType
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.pooling_metadata import PoolingMetadata


class LlamaEmbeddingModel(nn.Module):
    """A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
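        # Editorial note (not in the original source): PoolingType.LAST with
        # normalize=True is expected to take each sequence's final hidden
        # state and L2-normalize it, the usual recipe for turning a
        # decoder-only LLM into a sequence-embedding model.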

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, inputs_embeds)
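
    # Editorial note: `forward` returns per-token hidden states rather than
    # logits; the engine is then expected to hand those hidden states to
    # `pooler` below to reduce them into one embedding per sequence.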
    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Checkpoints store the attention and MLP projections as separate
        # tensors, while this Llama implementation fuses them (q/k/v into
        # qkv_proj, gate/up into gate_up_proj), so checkpoint weight names
        # must be remapped onto shards of the fused parameters.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
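        # Example (hypothetical checkpoint name, prefix elided): a tensor
        # named "...self_attn.q_proj.weight" is renamed to
        # "...self_attn.qkv_proj.weight" and written into the "q" shard of
        # that fused parameter by its weight_loader.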
        params_dict = dict(self.model.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched (for/else): load the weight
                # directly under its original name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
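

if __name__ == "__main__":
    # Editorial sketch, not part of the original module: a minimal,
    # self-contained imitation of what LAST pooling with normalize=True is
    # expected to produce, using plain torch so it runs without the
    # Aphrodite engine. The shapes and packed-prompt layout below are
    # assumptions for illustration only.
    hidden_states = torch.randn(7, 16)  # 7 packed tokens, hidden size 16
    prompt_lens = [3, 4]                # two sequences packed back to back
    last_token_idx = torch.tensor(prompt_lens).cumsum(dim=0) - 1  # [2, 6]
    embeddings = hidden_states[last_token_idx]                    # (2, 16)
    embeddings = nn.functional.normalize(embeddings, p=2, dim=-1)
    print(embeddings.shape)  # torch.Size([2, 16])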