1
0

__init__.py 1.4 KB

12345678910111213141516171819202122232425262728293031
  1. from typing import Optional
  2. from torch import nn
  3. from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
  4. ModelConfig, ParallelConfig,
  5. SchedulerConfig, VisionLanguageConfig)
  6. from aphrodite.modeling.model_loader.loader import (BaseModelLoader,
  7. get_model_loader)
  8. from aphrodite.modeling.model_loader.utils import (get_architecture_class_name,
  9. get_model_architecture)
  10. def get_model(
  11. *, model_config: ModelConfig, load_config: LoadConfig,
  12. device_config: DeviceConfig, parallel_config: ParallelConfig,
  13. scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
  14. vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
  15. loader = get_model_loader(load_config)
  16. return loader.load_model(model_config=model_config,
  17. device_config=device_config,
  18. lora_config=lora_config,
  19. vision_language_config=vision_language_config,
  20. parallel_config=parallel_config,
  21. scheduler_config=scheduler_config)
  22. __all__ = [
  23. "get_model", "get_model_loader", "BaseModelLoader",
  24. "get_architecture_class_name", "get_model_architecture"
  25. ]