interfaces.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
  2. Union, overload, runtime_checkable)
  3. from loguru import logger
  4. from typing_extensions import TypeGuard
  5. from aphrodite.common.config import (LoRAConfig, MultiModalConfig,
  6. SchedulerConfig)
  7. @runtime_checkable
  8. class SupportsVision(Protocol):
  9. """The interface required for all vision language models (VLMs)."""
  10. supports_vision: ClassVar[Literal[True]] = True
  11. """
  12. A flag that indicates this model supports vision inputs.
  13. Note:
  14. There is no need to redefine this flag if this class is in the
  15. MRO of your model class.
  16. """
  17. def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
  18. ...
  19. # We can't use runtime_checkable with ClassVar for issubclass checks
  20. # so we need to treat the class as an instance and use isinstance instead
  21. @runtime_checkable
  22. class _SupportsVisionType(Protocol):
  23. supports_vision: Literal[True]
  24. def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
  25. ...
  26. @overload
  27. def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
  28. ...
  29. @overload
  30. def supports_vision(model: object) -> TypeGuard[SupportsVision]:
  31. ...
  32. def supports_vision(
  33. model: Union[Type[object], object],
  34. ) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
  35. if isinstance(model, type):
  36. return isinstance(model, _SupportsVisionType)
  37. return isinstance(model, SupportsVision)
  38. @runtime_checkable
  39. class SupportsLoRA(Protocol):
  40. """The interface required for all models that support LoRA."""
  41. supports_lora: ClassVar[Literal[True]] = True
  42. """
  43. A flag that indicates this model supports LoRA.
  44. Note:
  45. There is no need to redefine this flag if this class is in the
  46. MRO of your model class.
  47. """
  48. packed_modules_mapping: ClassVar[Dict[str, List[str]]]
  49. supported_lora_modules: ClassVar[List[str]]
  50. embedding_modules: ClassVar[Dict[str, str]]
  51. embedding_padding_modules: ClassVar[List[str]]
  52. # lora_config is None when LoRA is not enabled
  53. def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
  54. ...
  55. # We can't use runtime_checkable with ClassVar for issubclass checks
  56. # so we need to treat the class as an instance and use isinstance instead
  57. @runtime_checkable
  58. class _SupportsLoRAType(Protocol):
  59. supports_lora: Literal[True]
  60. packed_modules_mapping: Dict[str, List[str]]
  61. supported_lora_modules: List[str]
  62. embedding_modules: Dict[str, str]
  63. embedding_padding_modules: List[str]
  64. def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
  65. ...
  66. @overload
  67. def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
  68. ...
  69. @overload
  70. def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
  71. ...
  72. def supports_lora(
  73. model: Union[Type[object], object],
  74. ) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
  75. result = _supports_lora(model)
  76. if not result:
  77. lora_attrs = (
  78. "packed_modules_mapping",
  79. "supported_lora_modules",
  80. "embedding_modules",
  81. "embedding_padding_modules",
  82. )
  83. missing_attrs = tuple(attr for attr in lora_attrs
  84. if not hasattr(model, attr))
  85. if getattr(model, "supports_lora", False):
  86. if missing_attrs:
  87. logger.warning(
  88. f"The model ({model}) sets `supports_lora=True`, "
  89. "but is missing LoRA-specific attributes: "
  90. f"{missing_attrs}", )
  91. else:
  92. if not missing_attrs:
  93. logger.warning(
  94. f"The model ({model}) contains all LoRA-specific "
  95. "attributes, but does not set `supports_lora=True`.")
  96. return result
  97. def _supports_lora(
  98. model: Union[Type[object], object],
  99. ) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
  100. if isinstance(model, type):
  101. return isinstance(model, _SupportsLoRAType)
  102. return isinstance(model, SupportsLoRA)
  103. @runtime_checkable
  104. class HasInnerState(Protocol):
  105. """The interface required for all models that has inner state."""
  106. has_inner_state: ClassVar[Literal[True]] = True
  107. """
  108. A flag that indicates this model has inner state.
  109. Models that has inner state usually need access to the scheduler_config
  110. for max_num_seqs ,etc... (Currently only used by Jamba)
  111. """
  112. def __init__(self,
  113. *,
  114. scheduler_config: Optional[SchedulerConfig] = None) -> None:
  115. ...
  116. @runtime_checkable
  117. class _HasInnerStateType(Protocol):
  118. has_inner_state: ClassVar[Literal[True]]
  119. def __init__(self,
  120. *,
  121. scheduler_config: Optional[SchedulerConfig] = None) -> None:
  122. ...
  123. @overload
  124. def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
  125. ...
  126. @overload
  127. def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
  128. ...
  129. def has_inner_state(
  130. model: Union[Type[object], object]
  131. ) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
  132. if isinstance(model, type):
  133. return isinstance(model, _HasInnerStateType)
  134. return isinstance(model, HasInnerState)