from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
                    Union, overload, runtime_checkable)

from loguru import logger
from typing_extensions import TypeIs

from aphrodite.common.config import (LoRAConfig, MultiModalConfig,
                                     SchedulerConfig)


@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    supports_multimodal: Literal[True]

    # A class object is callable (calling it constructs an instance), so the
    # constructor signature is expressed as __call__ on the "instance".
    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...


@overload
def supports_multimodal(
        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Check whether a model class or instance supports multi-modal inputs.

    A class is checked against ``_SupportsMultiModalType`` (the class object
    itself is treated as an instance of that protocol). An instance is checked
    against :class:`SupportsMultiModal`.

    Note:
        Runtime protocol checks only verify that the required attributes
        exist, not their values or signatures.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]

    # lora_config is None when LoRA is not enabled
    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

    # A class object is callable (calling it constructs an instance), so the
    # constructor signature is expressed as __call__ on the "instance".
    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...


@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Check whether a model class or instance supports LoRA.

    When the check fails, a warning is logged if the model's
    ``supports_lora`` flag and its LoRA-specific attributes disagree:

    * the flag is truthy, but some LoRA-specific attributes are missing; or
    * every LoRA-specific attribute is present, but the flag is not set.

    Returns:
        The result of the underlying protocol check.
    """
    result = _supports_lora(model)
    if result:
        return result

    lora_attrs = (
        "packed_modules_mapping",
        "supported_lora_modules",
        "embedding_modules",
        "embedding_padding_modules",
    )
    missing_attrs = tuple(attr for attr in lora_attrs
                          if not hasattr(model, attr))
    declares_lora = getattr(model, "supports_lora", False)

    if declares_lora and missing_attrs:
        logger.warning(
            f"The model ({model}) sets `supports_lora=True`, "
            "but is missing LoRA-specific attributes: "
            f"{missing_attrs}", )
    elif not declares_lora and not missing_attrs:
        logger.warning(
            f"The model ({model}) contains all LoRA-specific "
            "attributes, but does not set `supports_lora=True`.")

    return result


def _supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Run the LoRA protocol check without logging.

    A class is checked against ``_SupportsLoRAType``; an instance is checked
    against :class:`SupportsLoRA`. Only attribute presence is verified.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)


@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.

    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. (Currently only used by Jamba and Mamba.)
    """

    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _HasInnerStateType(Protocol):
    has_inner_state: Literal[True]

    # A class object is callable (calling it constructs an instance), so the
    # constructor signature is expressed as __call__ on the "instance".
    def __call__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...


# The Type[object] overload must come first: a class is also an `object`, so
# listing `object` first would make the Type[...] overload unreachable.
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
    """Check whether a model class or instance has inner state.

    A class is checked against ``_HasInnerStateType`` (the class object itself
    is treated as an instance of that protocol). An instance is checked
    against :class:`HasInnerState`. Only attribute presence is verified.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)