interfaces.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
  2. Union, overload, runtime_checkable)
  3. from loguru import logger
  4. from typing_extensions import TypeGuard
  5. from aphrodite.common.config import LoRAConfig, VisionLanguageConfig
  6. @runtime_checkable
  7. class SupportsVision(Protocol):
  8. """The interface required for all vision language models (VLMs)."""
  9. supports_vision: ClassVar[Literal[True]] = True
  10. """
  11. A flag that indicates this model supports vision inputs.
  12. Note:
  13. There is no need to redefine this flag if this class is in the
  14. MRO of your model class.
  15. """
  16. def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
  17. ...
  18. # We can't use runtime_checkable with ClassVar for issubclass checks
  19. # so we need to treat the class as an instance and use isinstance instead
  20. @runtime_checkable
  21. class _SupportsVisionType(Protocol):
  22. supports_vision: Literal[True]
  23. def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
  24. ...
  25. @overload
  26. def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
  27. ...
  28. @overload
  29. def supports_vision(model: object) -> TypeGuard[SupportsVision]:
  30. ...
  31. def supports_vision(
  32. model: Union[Type[object], object],
  33. ) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
  34. if isinstance(model, type):
  35. return isinstance(model, _SupportsVisionType)
  36. return isinstance(model, SupportsVision)
  37. @runtime_checkable
  38. class SupportsLoRA(Protocol):
  39. """The interface required for all models that support LoRA."""
  40. supports_lora: ClassVar[Literal[True]] = True
  41. """
  42. A flag that indicates this model supports LoRA.
  43. Note:
  44. There is no need to redefine this flag if this class is in the
  45. MRO of your model class.
  46. """
  47. packed_modules_mapping: ClassVar[Dict[str, List[str]]]
  48. supported_lora_modules: ClassVar[List[str]]
  49. embedding_modules: ClassVar[Dict[str, str]]
  50. embedding_padding_modules: ClassVar[List[str]]
  51. # lora_config is None when LoRA is not enabled
  52. def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
  53. ...
  54. # We can't use runtime_checkable with ClassVar for issubclass checks
  55. # so we need to treat the class as an instance and use isinstance instead
  56. @runtime_checkable
  57. class _SupportsLoRAType(Protocol):
  58. supports_lora: Literal[True]
  59. packed_modules_mapping: Dict[str, List[str]]
  60. supported_lora_modules: List[str]
  61. embedding_modules: Dict[str, str]
  62. embedding_padding_modules: List[str]
  63. def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
  64. ...
  65. @overload
  66. def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
  67. ...
  68. @overload
  69. def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
  70. ...
  71. def supports_lora(
  72. model: Union[Type[object], object],
  73. ) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
  74. result = _supports_lora(model)
  75. if not result:
  76. lora_attrs = (
  77. "packed_modules_mapping",
  78. "supported_lora_modules",
  79. "embedding_modules",
  80. "embedding_padding_modules",
  81. )
  82. missing_attrs = tuple(attr for attr in lora_attrs
  83. if not hasattr(model, attr))
  84. if getattr(model, "supports_lora", False):
  85. if missing_attrs:
  86. logger.warning(
  87. f"The model ({model}) sets `supports_lora=True`, "
  88. "but is missing LoRA-specific attributes: "
  89. f"{missing_attrs}", )
  90. else:
  91. if not missing_attrs:
  92. logger.warning(
  93. f"The model ({model}) contains all LoRA-specific "
  94. "attributes, but does not set `supports_lora=True`.")
  95. return result
  96. def _supports_lora(
  97. model: Union[Type[object], object],
  98. ) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
  99. if isinstance(model, type):
  100. return isinstance(model, _SupportsLoRAType)
  101. return isinstance(model, SupportsLoRA)