# worker_manager.py
  1. from typing import Any, Optional, Set, Type
  2. import torch
  3. from aphrodite.adapter_commons.utils import (add_adapter_worker,
  4. apply_adapters_worker,
  5. list_adapters_worker,
  6. set_active_adapters_worker)
  7. from aphrodite.adapter_commons.worker_manager import AbstractWorkerManager
  8. from aphrodite.common.config import PromptAdapterConfig
  9. from aphrodite.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
  10. PromptAdapterModel,
  11. PromptAdapterModelManager,
  12. create_prompt_adapter_manager)
  13. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  14. class WorkerPromptAdapterManager(AbstractWorkerManager):
  15. """WorkerPromptAdapterManager that manages
  16. prompt_adapter models on the worker side.
  17. Every request, the requested prompt_adapters will be
  18. loaded (unless they are already loaded),
  19. and every other prompt_adapter will be unloaded."""
  20. _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager
  21. def __init__(
  22. self,
  23. max_num_seqs: int,
  24. max_num_batched_tokens: int,
  25. device: torch.device,
  26. prompt_adapter_config: PromptAdapterConfig,
  27. prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
  28. ):
  29. self._adapter_manager: PromptAdapterModelManager
  30. self.max_num_seqs = max_num_seqs
  31. self.max_num_batched_tokens = max_num_batched_tokens
  32. self._prompt_adapter_model_cls = prompt_adapter_model_cls
  33. self.prompt_adapter_config = prompt_adapter_config
  34. super().__init__(device)
  35. @property
  36. def is_enabled(self) -> bool:
  37. return True
  38. def create_prompt_adapter_manager(
  39. self,
  40. model: torch.nn.Module,
  41. ) -> Any:
  42. prompt_adapter_manager = create_prompt_adapter_manager(
  43. model,
  44. max_num_seqs=self.max_num_seqs,
  45. max_num_batched_tokens=self.max_num_batched_tokens,
  46. prompt_adapter_config=self.prompt_adapter_config,
  47. prompt_adapter_manager_cls=self._manager_cls,
  48. )
  49. self._adapter_manager = prompt_adapter_manager
  50. return prompt_adapter_manager.model
  51. def _load_adapter(
  52. self, prompt_adapter_request: PromptAdapterRequest
  53. ) -> PromptAdapterModel:
  54. try:
  55. prompt_adapter = (
  56. self._prompt_adapter_model_cls.from_local_checkpoint(
  57. prompt_adapter_request.prompt_adapter_local_path,
  58. prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
  59. num_virtual_tokens=prompt_adapter_request.
  60. prompt_adapter_num_virtual_tokens,
  61. config=self.prompt_adapter_config,
  62. device=str(self.device),
  63. ))
  64. except Exception as e:
  65. raise RuntimeError(
  66. f"Loading prompt_adapter "
  67. f"{prompt_adapter_request.prompt_adapter_local_path}"
  68. f" failed") from e
  69. return prompt_adapter
  70. def add_dummy_prompt_adapter(
  71. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  72. return True
  73. def pin_adapter(self, adapter_id: int) -> bool:
  74. return self._adapter_manager.pin_adapter(adapter_id)
  75. def set_active_adapters(self, requests: Set[Any],
  76. mapping: Optional[Any]) -> None:
  77. set_active_adapters_worker(requests, mapping, self._apply_adapters,
  78. self._adapter_manager.set_adapter_mapping)
  79. def add_adapter(self, adapter_request: Any) -> bool:
  80. return add_adapter_worker(adapter_request, self.list_adapters,
  81. self._load_adapter,
  82. self._adapter_manager.add_adapter,
  83. self._adapter_manager.activate_adapter)
  84. def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
  85. apply_adapters_worker(adapter_requests, self.list_adapters,
  86. self._adapter_manager.adapter_slots,
  87. self.remove_adapter, self.add_adapter)
  88. def remove_adapter(self, adapter_id: int) -> bool:
  89. return self._adapter_manager.remove_adapter(adapter_id)
  90. def remove_all_adapters(self):
  91. self._adapter_manager.remove_all_adapters()
  92. def list_adapters(self) -> Set[int]:
  93. return list_adapters_worker(self._adapter_manager.list_adapters)
  94. class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
  95. """WorkerPromptAdapterManager that manages
  96. prompt_adapter models on the worker side.
  97. Uses an LRU Cache. Every request, the requested
  98. prompt_adapters will be loaded (unless they are already loaded)
  99. and least recently used prompt_adapters will
  100. be unloaded if the cache is above capacity."""
  101. _prompt_adapter_manager_cls: Type[
  102. LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager
  103. def create_prompt_adapter_manager(
  104. self,
  105. model: torch.nn.Module,
  106. ) -> Any:
  107. prompt_adapter_manager = create_prompt_adapter_manager(
  108. model,
  109. max_num_seqs=self.max_num_seqs,
  110. max_num_batched_tokens=self.max_num_batched_tokens,
  111. prompt_adapter_config=self.prompt_adapter_config,
  112. prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
  113. self._adapter_manager: LRUCachePromptAdapterModelManager = (
  114. prompt_adapter_manager)
  115. return prompt_adapter_manager.model
  116. def _apply_adapters(
  117. self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
  118. prompt_adapters_map = {
  119. prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
  120. for prompt_adapter_request in prompt_adapter_requests
  121. if prompt_adapter_request
  122. }
  123. if len(prompt_adapters_map
  124. ) > self._adapter_manager.prompt_adapter_slots:
  125. raise RuntimeError(
  126. f"Number of requested prompt_adapters "
  127. f"({len(prompt_adapters_map)}) is greater "
  128. "than the number of GPU prompt_adapter slots "
  129. f"({self._adapter_manager.prompt_adapter_slots}).")
  130. for prompt_adapter in prompt_adapters_map.values():
  131. self.add_adapter(prompt_adapter)
  132. def add_adapter(self,
  133. prompt_adapter_request: PromptAdapterRequest) -> bool:
  134. if prompt_adapter_request.prompt_adapter_id not in self.list_adapters(
  135. ):
  136. # Remove before we load the new prompt_adapter to save memory
  137. if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
  138. self._adapter_manager.remove_oldest_adapter()
  139. prompt_adapter = self._load_adapter(prompt_adapter_request)
  140. loaded = self._adapter_manager.add_adapter(prompt_adapter)
  141. else:
  142. # If the prompt_adapter is already loaded, just touch it to
  143. # update its position in the caches
  144. loaded = self._adapter_manager.get_adapter(
  145. prompt_adapter_request.prompt_adapter_id) is not None
  146. self._adapter_manager.activate_adapter(
  147. prompt_adapter_request.prompt_adapter_id)
  148. return loaded