worker_manager.py
from typing import Any, Set, Type

import torch

from aphrodite.adapter_commons.utils import (add_adapter_worker,
                                             apply_adapters_worker,
                                             list_adapters_worker,
                                             set_active_adapters_worker)
from aphrodite.adapter_commons.worker_manager import AbstractWorkerManager
from aphrodite.common.config import ControlVectorConfig
from aphrodite.control_vectors.models import (
    ControlVectorModel, ControlVectorModelManager,
    LRUCacheControlVectorModelManager, create_cv_manager)
from aphrodite.control_vectors.request import ControlVectorRequest


class WorkerControlVectorManager(AbstractWorkerManager):
    """WorkerControlVectorManager that manages
    control vector models on the worker side.

    On every request, the requested control vectors will be
    loaded (unless they are already loaded),
    and every other control vector will be unloaded."""

    _manager_cls: Type[ControlVectorModelManager] = ControlVectorModelManager

    def __init__(
        self,
        device: torch.device,
        control_vector_config: ControlVectorConfig,
        control_vector_model_cls: Type[ControlVectorModel] = ControlVectorModel
    ):
        self._adapter_manager: ControlVectorModelManager
        self._control_vector_model_cls = control_vector_model_cls
        self.control_vector_config = control_vector_config
        super().__init__(device)

    @property
    def is_enabled(self) -> bool:
        return True
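
    # Create the control vector manager for the given model via
    # create_cv_manager and return the manager's model object.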
    def create_control_vector_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        control_vector_manager = create_cv_manager(
            model,
            control_vector_config=self.control_vector_config,
            control_vector_manager_cls=self._manager_cls,
        )
        self._adapter_manager = control_vector_manager
        return control_vector_manager.model
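
    # Load a single control vector from the request's local checkpoint path,
    # wrapping any failure in a RuntimeError that names the offending path.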
    def _load_adapter(
            self, control_vector_request: ControlVectorRequest
    ) -> ControlVectorModel:
        try:
            control_vector = (
                self._control_vector_model_cls.from_local_checkpoint(
                    control_vector_request.control_vector_local_path,
                    control_vector_id=control_vector_request.control_vector_id,
                    config=self.control_vector_config,
                    device=str(self.device),
                    scale_factor=control_vector_request.scale_factor))
        except Exception as e:
            raise RuntimeError(
                f"Loading control vector "
                f"{control_vector_request.control_vector_local_path}"
                f" failed") from e
        return control_vector
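
    # No weights are loaded for a dummy control vector; simply report success.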
    def add_dummy_control_vector(
            self, control_vector_request: ControlVectorRequest) -> bool:
        return True

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)
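
    # At most one control vector may be active at a time; the single request's
    # adapter_id (or None when there is no request) becomes the active mapping.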
    def set_active_adapters(self, requests: Set[Any]) -> None:
        assert len(requests) <= 1, (
            "Currently, no more than one control vector is supported "
            "at a time")
        if requests:
            mapping = [req.adapter_id for req in requests][0]
        else:
            mapping = None
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)
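
    # Delegate to the shared adapter_commons helpers: load the control vector
    # if it is not already present, then register and activate it.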
    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)


class LRUCacheWorkerControlVectorManager(WorkerControlVectorManager):
    """WorkerControlVectorManager that manages
    control vector models on the worker side.

    Uses an LRU cache. On every request, the requested
    control vectors will be loaded (unless they are already loaded)
    and the least recently used control vectors will
    be unloaded if the cache is above capacity."""

    _control_vector_manager_cls: Type[
        LRUCacheControlVectorModelManager] = LRUCacheControlVectorModelManager

    def create_control_vector_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        control_vector_manager = create_cv_manager(
            model,
            control_vector_config=self.control_vector_config,
            control_vector_manager_cls=self._control_vector_manager_cls)
        self._adapter_manager: LRUCacheControlVectorModelManager = (
            control_vector_manager)
        return control_vector_manager.model
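
    # Deduplicate the requested control vectors by id, refuse more requests
    # than there are adapter slots, and add each one through the LRU cache.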
    def _apply_adapters(
            self, control_vector_requests: Set[ControlVectorRequest]) -> None:
        control_vectors_map = {
            control_vector_request.control_vector_id: control_vector_request
            for control_vector_request in control_vector_requests
            if control_vector_request
        }
        if len(control_vectors_map) > self._adapter_manager.adapter_slots:
            raise RuntimeError(
                f"Number of requested control vectors "
                f"({len(control_vectors_map)}) is greater "
                "than the number of GPU control vector slots "
                f"({self._adapter_manager.adapter_slots}).")
        for control_vector in control_vectors_map.values():
            self.add_adapter(control_vector)
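
    # Load the control vector only if it is not already cached, evicting the
    # least recently used entry when the cache is full, then activate it.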
    def add_adapter(self,
                    control_vector_request: ControlVectorRequest) -> bool:
        if control_vector_request.control_vector_id not in self.list_adapters():
            # Remove before we load the new control vector to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                self._adapter_manager.remove_oldest_adapter()
            control_vector = self._load_adapter(control_vector_request)
            loaded = self._adapter_manager.add_adapter(control_vector)
        else:
            loaded = self._adapter_manager.get_adapter(
                control_vector_request.adapter_id)
        self._adapter_manager.activate_adapter(
            control_vector_request.control_vector_id)
        return loaded
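

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module above): how a worker might wire
# up the LRU-cached manager. The ControlVectorRequest constructor arguments
# shown here (name, id, local path, scale factor) are assumptions for
# illustration; see aphrodite.control_vectors.request for the real fields.
#
#     manager = LRUCacheWorkerControlVectorManager(
#         device=torch.device("cuda:0"),
#         control_vector_config=control_vector_config)
#     # Wrap the worker's model so control vectors can be applied to it.
#     model = manager.create_control_vector_manager(model)
#
#     # Activate a single control vector for the next forward pass.
#     request = ControlVectorRequest(
#         control_vector_name="calm",
#         control_vector_id=1,
#         control_vector_local_path="/path/to/control_vector",
#         scale_factor=1.0)
#     manager.set_active_adapters({request})
#
#     # Drop all cached control vectors, e.g. when shutting the worker down.
#     manager.remove_all_adapters()
# ---------------------------------------------------------------------------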