__init__.py 1.3 KB

12345678910111213141516171819202122232425262728293031
  1. from typing import Optional
  2. from torch import nn
  3. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  4. LoRAConfig, ModelConfig, ParallelConfig,
  5. SchedulerConfig)
  6. from aphrodite.modeling.model_loader.loader import (BaseModelLoader,
  7. get_model_loader)
  8. from aphrodite.modeling.model_loader.utils import (get_architecture_class_name,
  9. get_model_architecture)
  10. def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
  11. device_config: DeviceConfig, parallel_config: ParallelConfig,
  12. scheduler_config: SchedulerConfig,
  13. lora_config: Optional[LoRAConfig],
  14. cache_config: CacheConfig) -> nn.Module:
  15. loader = get_model_loader(load_config)
  16. return loader.load_model(model_config=model_config,
  17. device_config=device_config,
  18. lora_config=lora_config,
  19. parallel_config=parallel_config,
  20. scheduler_config=scheduler_config,
  21. cache_config=cache_config)
  22. __all__ = [
  23. "get_model", "get_model_loader", "BaseModelLoader",
  24. "get_architecture_class_name", "get_model_architecture"
  25. ]