12345678910111213141516171819202122232425262728293031 |
- from typing import Optional
- from torch import nn
- from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
- LoRAConfig, ModelConfig, ParallelConfig,
- SchedulerConfig)
- from aphrodite.modeling.model_loader.loader import (BaseModelLoader,
- get_model_loader)
- from aphrodite.modeling.model_loader.utils import (get_architecture_class_name,
- get_model_architecture)
def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
              device_config: DeviceConfig, parallel_config: ParallelConfig,
              scheduler_config: SchedulerConfig,
              lora_config: Optional[LoRAConfig],
              cache_config: CacheConfig) -> nn.Module:
    """Load a model using the loader selected by ``load_config``.

    Every argument is keyword-only. ``load_config`` is consumed only by
    ``get_model_loader`` to obtain a loader; all remaining configs are
    handed, unchanged, to that loader's ``load_model`` method.

    Returns:
        Whatever the loader's ``load_model`` call returns (annotated as
        ``nn.Module``).
    """
    # Configs forwarded verbatim to the loader; ``load_config`` is
    # deliberately absent because it is only used to pick the loader.
    forwarded_configs = {
        "model_config": model_config,
        "device_config": device_config,
        "lora_config": lora_config,
        "parallel_config": parallel_config,
        "scheduler_config": scheduler_config,
        "cache_config": cache_config,
    }
    return get_model_loader(load_config).load_model(**forwarded_configs)
# Names exported by ``from <this module> import *``.
__all__ = [
    # Defined in this module.
    "get_model",
    # Imported from the loader module.
    "get_model_loader",
    "BaseModelLoader",
    # Imported from the loader utils module.
    "get_architecture_class_name",
    "get_model_architecture",
]
|