# interfaces.py
  1. from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
  2. Union, overload, runtime_checkable)
  3. from loguru import logger
  4. from typing_extensions import TypeIs
  5. from aphrodite.common.config import (LoRAConfig, MultiModalConfig,
  6. SchedulerConfig)
@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    # Mirror of `SupportsMultiModal` with the ClassVar flag re-declared as a
    # plain attribute, so a *class* object passes the isinstance check.
    supports_multimodal: Literal[True]

    # `__call__` mirrors the constructor signature: calling the class object
    # is what instantiates it.
    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
        ...
  26. @overload
  27. def supports_multimodal(
  28. model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
  29. ...
  30. @overload
  31. def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
  32. ...
  33. def supports_multimodal(
  34. model: Union[Type[object], object],
  35. ) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
  36. if isinstance(model, type):
  37. return isinstance(model, _SupportsMultiModalType)
  38. return isinstance(model, SupportsMultiModal)
@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # LoRA bookkeeping attributes that every implementing model must declare.
    # NOTE(review): their exact semantics are defined by the LoRA runtime
    # elsewhere in the project — not visible from this file.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]

    # lora_config is None when LoRA is not enabled
    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    # Mirror of `SupportsLoRA` with every ClassVar re-declared as a plain
    # attribute, so a *class* object passes the isinstance check.
    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

    # `__call__` mirrors the constructor signature of `SupportsLoRA.__init__`.
    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...
  67. @overload
  68. def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
  69. ...
  70. @overload
  71. def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
  72. ...
  73. def supports_lora(
  74. model: Union[Type[object], object],
  75. ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
  76. result = _supports_lora(model)
  77. if not result:
  78. lora_attrs = (
  79. "packed_modules_mapping",
  80. "supported_lora_modules",
  81. "embedding_modules",
  82. "embedding_padding_modules",
  83. )
  84. missing_attrs = tuple(attr for attr in lora_attrs
  85. if not hasattr(model, attr))
  86. if getattr(model, "supports_lora", False):
  87. if missing_attrs:
  88. logger.warning(
  89. f"The model ({model}) sets `supports_lora=True`, "
  90. "but is missing LoRA-specific attributes: "
  91. f"{missing_attrs}", )
  92. else:
  93. if not missing_attrs:
  94. logger.warning(
  95. f"The model ({model}) contains all LoRA-specific "
  96. "attributes, but does not set `supports_lora=True`.")
  97. return result
  98. def _supports_lora(
  99. model: Union[Type[object], object],
  100. ) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
  101. if isinstance(model, type):
  102. return isinstance(model, _SupportsLoRAType)
  103. return isinstance(model, SupportsLoRA)
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. (Currently only used by Jamba and Mamba.)
    """

    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        ...
  117. @runtime_checkable
  118. class _HasInnerStateType(Protocol):
  119. has_inner_state: ClassVar[Literal[True]]
  120. def __init__(self,
  121. *,
  122. scheduler_config: Optional[SchedulerConfig] = None) -> None:
  123. ...
  124. @overload
  125. def has_inner_state(model: object) -> TypeIs[HasInnerState]:
  126. ...
  127. @overload
  128. def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
  129. ...
  130. def has_inner_state(
  131. model: Union[Type[object], object]
  132. ) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
  133. if isinstance(model, type):
  134. return isinstance(model, _HasInnerStateType)
  135. return isinstance(model, HasInnerState)