utils.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. from typing import Any, Callable, Dict, Optional, Set
# Adapter-manager (model-level) helper functions
  3. def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
  4. deactivate_func: Callable) -> bool:
  5. if adapter_id in active_adapters:
  6. deactivate_func(adapter_id)
  7. active_adapters.pop(adapter_id)
  8. return True
  9. return False
  10. def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
  11. capacity: int, add_func: Callable) -> bool:
  12. if adapter.id not in registered_adapters:
  13. if len(registered_adapters) >= capacity:
  14. raise RuntimeError('No free adapter slots.')
  15. add_func(adapter)
  16. registered_adapters[adapter.id] = adapter
  17. return True
  18. return False
  19. def set_adapter_mapping(mapping: Any, last_mapping: Any,
  20. set_mapping_func: Callable) -> Any:
  21. if last_mapping != mapping:
  22. set_mapping_func(mapping)
  23. return mapping
  24. return last_mapping
  25. def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
  26. deactivate_func: Callable) -> bool:
  27. deactivate_func(adapter_id)
  28. return bool(registered_adapters.pop(adapter_id, None))
  29. def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
  30. return dict(registered_adapters)
  31. def get_adapter(adapter_id: int,
  32. registered_adapters: Dict[int, Any]) -> Optional[Any]:
  33. return registered_adapters.get(adapter_id, None)
# Worker-level helper functions
  35. def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
  36. apply_adapters_func,
  37. set_adapter_mapping_func) -> None:
  38. apply_adapters_func(requests)
  39. set_adapter_mapping_func(mapping)
  40. def add_adapter_worker(adapter_request: Any, list_adapters_func,
  41. load_adapter_func, add_adapter_func,
  42. activate_adapter_func) -> bool:
  43. if adapter_request.adapter_id in list_adapters_func():
  44. return False
  45. loaded_adapter = load_adapter_func(adapter_request)
  46. loaded = add_adapter_func(loaded_adapter)
  47. activate_adapter_func(loaded_adapter.id)
  48. return loaded
  49. def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
  50. adapter_slots: int, remove_adapter_func,
  51. add_adapter_func) -> None:
  52. models_that_exist = list_adapters_func()
  53. models_map = {
  54. adapter_request.adapter_id: adapter_request
  55. for adapter_request in adapter_requests if adapter_request
  56. }
  57. if len(models_map) > adapter_slots:
  58. raise RuntimeError(
  59. f"Number of requested models ({len(models_map)}) is greater "
  60. f"than the number of GPU model slots "
  61. f"({adapter_slots}).")
  62. new_models = set(models_map)
  63. models_to_add = new_models - models_that_exist
  64. models_to_remove = models_that_exist - new_models
  65. for adapter_id in models_to_remove:
  66. remove_adapter_func(adapter_id)
  67. for adapter_id in models_to_add:
  68. add_adapter_func(models_map[adapter_id])
  69. def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
  70. return set(adapter_manager_list_adapters_func())