models.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from abc import ABC, abstractmethod
  2. from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
  3. from torch import nn
  4. from loguru import logger
  5. from aphrodite.common.utils import LRUCache
  6. class AdapterModel(ABC):
  7. def __init__(self, model_id=None):
  8. self.id = model_id
  9. @abstractmethod
  10. def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
  11. # Common initialization code
  12. # Load weights or embeddings from local checkpoint
  13. raise NotImplementedError("Subclasses must implement this method.")
  14. T = TypeVar('T')
  15. class AdapterLRUCache(LRUCache[T]):
  16. def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
  17. None]):
  18. super().__init__(capacity)
  19. self.deactivate_fn = deactivate_fn
  20. def _on_remove(self, key: Hashable, value: T):
  21. logger.debug(f"Removing adapter int id: {key}")
  22. self.deactivate_fn(key)
  23. return super()._on_remove(key, value)
  24. class AdapterModelManager(ABC):
  25. def __init__(
  26. self,
  27. model: nn.Module,
  28. ):
  29. """Create a AdapterModelManager and adapter for a given model.
  30. Args:
  31. model: the model to be adapted.
  32. """
  33. self.model: nn.Module = model
  34. self._registered_adapters: Dict[int, Any] = {}
  35. # Dict instead of a Set for compatibility with LRUCache.
  36. self._active_adapters: Dict[int, None] = {}
  37. self.adapter_type = 'Adapter'
  38. self._last_mapping = None
  39. def __len__(self) -> int:
  40. return len(self._registered_adapters)
  41. @property
  42. @abstractmethod
  43. def adapter_slots(self):
  44. ...
  45. @property
  46. @abstractmethod
  47. def capacity(self):
  48. ...
  49. @abstractmethod
  50. def activate_adapter(self, adapter_id: int) -> bool:
  51. ...
  52. @abstractmethod
  53. def deactivate_adapter(self, adapter_id: int) -> bool:
  54. ...
  55. @abstractmethod
  56. def add_adapter(self, adapter: Any) -> bool:
  57. ...
  58. @abstractmethod
  59. def set_adapter_mapping(self, mapping: Any) -> None:
  60. ...
  61. @abstractmethod
  62. def remove_adapter(self, adapter_id: int) -> bool:
  63. ...
  64. @abstractmethod
  65. def remove_all_adapters(self):
  66. ...
  67. @abstractmethod
  68. def get_adapter(self, adapter_id: int) -> Optional[Any]:
  69. ...
  70. @abstractmethod
  71. def list_adapters(self) -> Dict[int, Any]:
  72. ...
  73. @abstractmethod
  74. def pin_adapter(self, adapter_id: int) -> bool:
  75. ...