__init__.py 1.5 KB

123456789101112131415161718192021222324252627282930313233
  1. from typing import Optional
  2. from torch import nn
  3. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  4. LoRAConfig, ModelConfig, MultiModalConfig,
  5. ParallelConfig, SchedulerConfig)
  6. from aphrodite.modeling.model_loader.loader import (BaseModelLoader,
  7. get_model_loader)
  8. from aphrodite.modeling.model_loader.utils import (get_architecture_class_name,
  9. get_model_architecture)
  def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
                device_config: DeviceConfig, parallel_config: ParallelConfig,
                scheduler_config: SchedulerConfig,
                lora_config: Optional[LoRAConfig],
                multimodal_config: Optional[MultiModalConfig],
                cache_config: CacheConfig) -> nn.Module:
      """Build a model by delegating to the loader chosen for ``load_config``.

      All arguments are keyword-only. ``load_config`` is used solely to obtain
      the loader via ``get_model_loader``; it is not passed on to
      ``load_model``. Every other config is forwarded unchanged, by keyword.

      Returns:
          Whatever the selected loader's ``load_model`` returns (annotated as
          an ``nn.Module``).
      """
      # Configs handed through to the loader as-is; ``load_config`` is
      # intentionally absent because it only selects the loader.
      forwarded_configs = {
          "model_config": model_config,
          "device_config": device_config,
          "lora_config": lora_config,
          "multimodal_config": multimodal_config,
          "parallel_config": parallel_config,
          "scheduler_config": scheduler_config,
          "cache_config": cache_config,
      }
      return get_model_loader(load_config).load_model(**forwarded_configs)
  # Names exported by ``from <package> import *``: the local ``get_model``
  # entry point plus the loader/utility names imported at the top of the file.
  __all__ = [
      "get_model",
      "get_model_loader",
      "BaseModelLoader",
      "get_architecture_class_name",
      "get_model_architecture",
  ]