# layers.py
  1. from dataclasses import dataclass
  2. from typing import Optional
  3. import torch
  4. from torch import nn
  5. from aphrodite.adapter_commons.layers import AdapterMapping
  6. from aphrodite.common.config import PromptAdapterConfig
  7. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  8. VocabParallelEmbedding)
@dataclass
class PromptAdapterMapping(AdapterMapping):
    """Index mapping for prompt adapters.

    Adds no fields of its own; it exists as a distinct type so
    prompt-adapter mappings can be told apart from other
    ``AdapterMapping`` subclasses.
    """
    pass
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
    """Wraps a ``VocabParallelEmbedding`` so that prompt-adapter
    ("soft prompt") embeddings can overwrite the base embedding's
    output at selected token positions during ``forward``.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Layer whose ``weight``/``embedding_dim`` we introspect for
        # dtype/device/shape.  When the base layer is itself a LoRA
        # wrapper (detected by class name), reach through to the
        # wrapped embedding underneath it.
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig) -> None:
        """Allocate zeroed storage for all prompt-adapter slots.

        ``embeddings_tensors`` has shape
        ``(max_prompt_adapters, max_prompt_adapter_token, embedding_dim)``
        and matches the base embedding weight's dtype and device.
        """
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        # Number of valid (non-padding) embedding rows stored per slot.
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)
        # Declared here, assigned later by set_mapping(); forward()
        # reads these, so set_mapping() must run first.
        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int) -> None:
        """Zero out the embeddings stored in adapter slot ``index``."""
        self.embeddings_tensors[index] = 0

    def set_prompt_adapter(
            self,
            index: int,
            adapter_model: Optional[torch.Tensor],
    ) -> None:
        """Install ``adapter_model`` into slot ``index``.

        The slot is cleared first; when ``adapter_model`` is ``None``
        the call is just a reset.  Otherwise the first
        ``adapter_model.shape[0]`` rows of the slot are filled and the
        slot's length is recorded in ``adapter_lengths``.
        """
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
            self,
            prompt_indices: torch.Tensor,
            prompt_embedding_indices: torch.Tensor,
    ) -> None:
        """Stage the per-token mapping used by ``forward``, moving both
        tensors to the embedding weight's device.

        NOTE(review): ``forward`` indexes ``prompt_embedding_indices``
        as ``[:, 0]`` / ``[:, 1]`` — presumably (adapter_slot,
        token_position) pairs, one row per prompt-adapter token; and
        treats ``-1`` entries of ``prompt_indices`` as "no adapter".
        Confirm against the caller that builds these tensors.
        """
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` with the base layer, then overwrite the positions
        selected by the current mapping with stored adapter embeddings.

        Requires ``set_mapping`` (and ``create_prompt_adapter_weights``)
        to have been called beforehand.
        """
        hidden_states = self.base_layer(x)
        # A 2-D index tensor signals that prompt-adapter tokens are
        # present in this batch; a flat/empty mapping skips the mixing.
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]
            # Update hidden states
            # NOTE(review): the count of True entries in ``valid_mask``
            # must equal ``gathered_embeddings.shape[0]`` — this
            # invariant is established by whoever calls set_mapping.
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states
  66. return hidden_states