123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
- Union, overload, runtime_checkable)
- from loguru import logger
- from typing_extensions import TypeIs
- from aphrodite.common.config import (LoRAConfig, MultiModalConfig,
- SchedulerConfig)
# runtime_checkable allows isinstance() checks against this Protocol
# (attribute-presence check only, not signature verification).
@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Mirror of :class:`SupportsMultiModal` used to isinstance-check a
    *class object* (its constructor call surface is ``__call__``)."""

    supports_multimodal: Literal[True]

    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...
@overload
def supports_multimodal(
        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Report whether ``model`` (a class or an instance) supports
    multi-modal inputs.

    A class object is matched against the ``__call__``-based mirror
    protocol, since a class is itself an instance of ``type``.
    """
    checked_protocol = (_SupportsMultiModalType
                        if isinstance(model, type) else SupportsMultiModal)
    return isinstance(model, checked_protocol)
@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # NOTE(review): semantics below are inferred from attribute names —
    # confirm against a concrete implementing model.
    # presumably maps a fused/packed module name -> original module names
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    # module names that LoRA adapters may target
    supported_lora_modules: ClassVar[List[str]]
    # presumably maps embedding module name -> weight attribute name
    embedding_modules: ClassVar[Dict[str, str]]
    # embedding modules whose weights may need padding for LoRA
    embedding_padding_modules: ClassVar[List[str]]

    # lora_config is None when LoRA is not enabled
    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Mirror of :class:`SupportsLoRA` used to isinstance-check a *class
    object* (its constructor call surface is ``__call__``)."""

    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...
@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Report whether ``model`` (a class or an instance) supports LoRA.

    As a diagnostic aid, emits a warning when the model is only
    partially conforming: either it sets ``supports_lora=True`` while
    missing required LoRA attributes, or it has all the attributes but
    never sets the flag.
    """
    result = _supports_lora(model)

    if not result:
        missing_attrs = tuple(
            attr for attr in ("packed_modules_mapping",
                              "supported_lora_modules",
                              "embedding_modules",
                              "embedding_padding_modules")
            if not hasattr(model, attr))

        claims_support = getattr(model, "supports_lora", False)
        if claims_support and missing_attrs:
            logger.warning(
                f"The model ({model}) sets `supports_lora=True`, "
                "but is missing LoRA-specific attributes: "
                f"{missing_attrs}", )
        elif not claims_support and not missing_attrs:
            logger.warning(
                f"The model ({model}) contains all LoRA-specific "
                "attributes, but does not set `supports_lora=True`.")

    return result
def _supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Raw LoRA-support check (no diagnostics), for classes or instances."""
    # A class object must be matched against the instance-style mirror
    # protocol; a plain instance uses SupportsLoRA directly.
    checked_protocol = (_SupportsLoRAType
                        if isinstance(model, type) else SupportsLoRA)
    return isinstance(model, checked_protocol)
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the
    scheduler_config for max_num_seqs, etc.
    (Currently only used by Jamba and Mamba)
    """

    # scheduler_config is None when no scheduler settings are needed
    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _HasInnerStateType(Protocol):
    """Mirror of :class:`HasInnerState` used to isinstance-check a *class
    object*.

    Consistent with ``_SupportsMultiModalType`` / ``_SupportsLoRAType``:
    the flag is a plain (non-ClassVar) member, and the constructor is
    described via ``__call__`` — a class object is itself an instance
    whose call surface is ``__call__``, not ``__init__``.
    """

    has_inner_state: Literal[True]

    def __call__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
# NOTE: the Type[object] overload must be listed first. A class argument
# also matches `object`, so with the instance overload first a static
# checker would narrow a class to TypeIs[HasInnerState] instead of
# TypeIs[Type[HasInnerState]]. This also matches the overload order of
# supports_multimodal() and supports_lora().
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
    """Report whether ``model`` (a class or an instance) declares inner
    state (the ``has_inner_state`` flag).

    A class object is matched against the mirror protocol, since a class
    is itself an instance of ``type``.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)
    return isinstance(model, HasInnerState)
|