from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from aphrodite.adapter_commons.layers import AdapterMapping
from aphrodite.common.config import PromptAdapterConfig
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding


@dataclass
class PromptAdapterMapping(AdapterMapping):
    """Prompt-adapter mapping; adds no fields or behavior beyond
    ``AdapterMapping``."""


class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
    """Embedding wrapper that splices prompt-adapter embeddings into the
    output of a base embedding layer.

    ``base_layer`` produces the ordinary token embeddings. ``forward`` then
    overwrites the rows selected by the current mapping (installed with
    ``set_mapping``) with rows taken from ``embeddings_tensors``.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Layer whose ``weight`` / ``embedding_dim`` determine the dtype,
        # device and width of the adapter buffers. If the base layer's class
        # name contains 'LoRA' (presumably a LoRA wrapper around the real
        # embedding), read those attributes from the layer it wraps.
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig):
        """Allocate zero-initialized adapter buffers.

        Both buffers live on ``emb_layer.weight.device``.

        * ``embeddings_tensors``: shape ``(max_prompt_adapters,
          max_prompt_adapter_token, embedding_dim)`` with the embedding
          weight's dtype; one slot of token embeddings per adapter.
        * ``adapter_lengths``: shape ``(max_prompt_adapters,)``, ``long``;
          the number of rows filled in each slot.

        ``indices_gpu`` and ``embedding_indices_gpu`` are only declared here
        (annotation without a value); ``set_mapping`` assigns them.

        Args:
            prompt_adapter_config: Supplies ``max_prompt_adapters`` and
                ``max_prompt_adapter_token``.
        """
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)

        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int):
        """Clear adapter slot ``index``: zero its embeddings and length.

        Args:
            index: Slot to clear (first dimension of ``embeddings_tensors``).
        """
        self.embeddings_tensors[index] = 0
        # Also zero the recorded length so it never describes embeddings
        # that were just cleared (e.g. after
        # ``set_prompt_adapter(index, None)``).
        self.adapter_lengths[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ):
        """Load an adapter's embeddings into slot ``index``.

        The slot is cleared first. If ``adapter_model`` is not ``None``, its
        rows are copied into the start of the slot and its row count is
        recorded in ``adapter_lengths``.

        Args:
            index: Slot to fill.
            adapter_model: Tensor whose first dimension is the number of
                prompt tokens. Each row presumably has length
                ``embedding_dim``, and the row count presumably must not
                exceed ``max_prompt_adapter_token``; verify against callers.
                ``None`` leaves the slot cleared.
        """
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ):
        """Install the token-to-adapter-embedding mapping used by
        ``forward``.

        Both tensors are moved to ``emb_layer.weight.device``.

        Args:
            prompt_indices: Per-token indices. ``forward`` overwrites the
                tokens whose entry is not ``-1``.
            prompt_embedding_indices: When multi-dimensional, ``forward``
                reads column 0 as the adapter slot and column 1 as the row
                within that slot.
        """
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` with the base layer, then splice in prompt-adapter
        embeddings for the tokens selected by the current mapping.

        The base layer's output tensor is modified in place and returned.
        ``set_mapping`` must have been called first, because it is the only
        place that assigns ``indices_gpu`` and ``embedding_indices_gpu``.
        """
        hidden_states = self.base_layer(x)
        # Only a multi-dimensional index tensor (presumably (N, 2)
        # slot/row pairs) triggers the splice; otherwise the base
        # embeddings are returned unchanged.
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Boolean-mask assignment. Assumes:
            # - the number of True entries in valid_mask equals the number
            #   of gathered rows;
            # - hidden_states is (num_tokens, embedding_dim).
            # TODO confirm both against the mapping builder.
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states