interfaces.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
  2. Union, overload, runtime_checkable)
  3. from loguru import logger
  4. from typing_extensions import TypeIs
  5. from aphrodite.common.config import (LoRAConfig, MultiModalConfig,
  6. SchedulerConfig)
  7. @runtime_checkable
  8. class SupportsMultiModal(Protocol):
  9. """
  10. The interface required for all multimodal (vision or audio) language
  11. models.
  12. """
  13. supports_multimodal: ClassVar[Literal[True]] = True
  14. """
  15. A flag that indicates this model supports multimodal inputs.
  16. Note:
  17. There is no need to redefine this flag if this class is in the
  18. MRO of your model class.
  19. """
  20. def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
  21. ...
  22. # We can't use runtime_checkable with ClassVar for issubclass checks
  23. # so we need to treat the class as an instance and use isinstance instead
  24. @runtime_checkable
  25. class _SupportsMultiModalType(Protocol):
  26. supports_multimodal: Literal[True]
  27. def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
  28. ...
  29. @overload
  30. def supports_multimodal(
  31. model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
  32. ...
  33. @overload
  34. def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
  35. ...
  36. def supports_multimodal(
  37. model: Union[Type[object], object],
  38. ) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
  39. if isinstance(model, type):
  40. return isinstance(model, _SupportsMultiModalType)
  41. return isinstance(model, SupportsMultiModal)
  42. @runtime_checkable
  43. class SupportsLoRA(Protocol):
  44. """The interface required for all models that support LoRA."""
  45. supports_lora: ClassVar[Literal[True]] = True
  46. """
  47. A flag that indicates this model supports LoRA.
  48. Note:
  49. There is no need to redefine this flag if this class is in the
  50. MRO of your model class.
  51. """
  52. packed_modules_mapping: ClassVar[Dict[str, List[str]]]
  53. supported_lora_modules: ClassVar[List[str]]
  54. embedding_modules: ClassVar[Dict[str, str]]
  55. embedding_padding_modules: ClassVar[List[str]]
  56. # lora_config is None when LoRA is not enabled
  57. def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
  58. ...
  59. # We can't use runtime_checkable with ClassVar for issubclass checks
  60. # so we need to treat the class as an instance and use isinstance instead
  61. @runtime_checkable
  62. class _SupportsLoRAType(Protocol):
  63. supports_lora: Literal[True]
  64. packed_modules_mapping: Dict[str, List[str]]
  65. supported_lora_modules: List[str]
  66. embedding_modules: Dict[str, str]
  67. embedding_padding_modules: List[str]
  68. def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
  69. ...
  70. @overload
  71. def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
  72. ...
  73. @overload
  74. def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
  75. ...
  76. def supports_lora(
  77. model: Union[Type[object], object],
  78. ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
  79. result = _supports_lora(model)
  80. if not result:
  81. lora_attrs = (
  82. "packed_modules_mapping",
  83. "supported_lora_modules",
  84. "embedding_modules",
  85. "embedding_padding_modules",
  86. )
  87. missing_attrs = tuple(attr for attr in lora_attrs
  88. if not hasattr(model, attr))
  89. if getattr(model, "supports_lora", False):
  90. if missing_attrs:
  91. logger.warning(
  92. f"The model ({model}) sets `supports_lora=True`, "
  93. "but is missing LoRA-specific attributes: "
  94. f"{missing_attrs}", )
  95. else:
  96. if not missing_attrs:
  97. logger.warning(
  98. f"The model ({model}) contains all LoRA-specific "
  99. "attributes, but does not set `supports_lora=True`.")
  100. return result
  101. def _supports_lora(
  102. model: Union[Type[object], object],
  103. ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
  104. if isinstance(model, type):
  105. return isinstance(model, _SupportsLoRAType)
  106. return isinstance(model, SupportsLoRA)
  107. @runtime_checkable
  108. class HasInnerState(Protocol):
  109. """The interface required for all models that has inner state."""
  110. has_inner_state: ClassVar[Literal[True]] = True
  111. """
  112. A flag that indicates this model has inner state.
  113. Models that has inner state usually need access to the scheduler_config
  114. for max_num_seqs ,etc... (Currently only used by Jamba and Mamba)
  115. """
  116. def __init__(self,
  117. *,
  118. scheduler_config: Optional[SchedulerConfig] = None) -> None:
  119. ...
  120. @runtime_checkable
  121. class _HasInnerStateType(Protocol):
  122. has_inner_state: ClassVar[Literal[True]]
  123. def __init__(self,
  124. *,
  125. scheduler_config: Optional[SchedulerConfig] = None) -> None:
  126. ...
  127. @overload
  128. def has_inner_state(model: object) -> TypeIs[HasInnerState]:
  129. ...
  130. @overload
  131. def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
  132. ...
  133. def has_inner_state(
  134. model: Union[Type[object], object]
  135. ) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
  136. if isinstance(model, type):
  137. return isinstance(model, _HasInnerStateType)
  138. return isinstance(model, HasInnerState)