1
0

__init__.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import functools
  2. import importlib
  3. from typing import Dict, List, Optional, Tuple, Type
  4. import torch.nn as nn
  5. from loguru import logger
  6. from aphrodite.common.utils import is_hip
  7. _GENERATION_MODELS = {
  8. "AquilaModel": ("llama", "LlamaForCausalLM"),
  9. "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
  10. "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
  11. "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
  12. "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
  13. "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
  14. "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
  15. "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
  16. "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
  17. "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
  18. "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
  19. "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
  20. "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
  21. "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
  22. "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
  23. "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
  24. "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  25. "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
  26. "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
  27. "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
  28. "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
  29. "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
  30. "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
  31. # For decapoda-research/llama-*
  32. "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
  33. "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
  34. "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
  35. "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
  36. # transformers's mpt class has lower case
  37. "MptForCausalLM": ("mpt", "MPTForCausalLM"),
  38. "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
  39. "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
  40. "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
  41. "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
  42. "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
  43. "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
  44. "OPTForCausalLM": ("opt", "OPTForCausalLM"),
  45. "OrionForCausalLM": ("orion", "OrionForCausalLM"),
  46. "PhiForCausalLM": ("phi", "PhiForCausalLM"),
  47. "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
  48. "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
  49. "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
  50. "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
  51. "Qwen2VLForConditionalGeneration":
  52. ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
  53. "RWForCausalLM": ("falcon", "FalconForCausalLM"),
  54. "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
  55. "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
  56. "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
  57. "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
  58. "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
  59. "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
  60. "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
  61. "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
  62. "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
  63. "MedusaModel": ("medusa", "Medusa"),
  64. "EAGLEModel": ("eagle", "EAGLE"),
  65. "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
  66. "SolarForCausalLM": ("solar", "SolarForCausalLM"),
  67. "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
  68. "GraniteForCausalLM": ("granite", "GraniteForCausalLM")
  69. }
  70. _EMBEDDING_MODELS = {
  71. "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
  72. }
  73. _MULTIMODAL_MODELS = {
  74. "Blip2ForConditionalGeneration":
  75. ("blip2", "Blip2ForConditionalGeneration"),
  76. "ChameleonForConditionalGeneration":
  77. ("chameleon", "ChameleonForConditionalGeneration"),
  78. "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
  79. "InternVLChatModel": ("internvl", "InternVLChatModel"),
  80. "LlavaForConditionalGeneration":
  81. ("llava", "LlavaForConditionalGeneration"),
  82. "LlavaNextForConditionalGeneration": ("llava_next",
  83. "LlavaNextForConditionalGeneration"),
  84. "LlavaNextVideoForConditionalGeneration":
  85. ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
  86. "MiniCPMV": ("minicpmv", "MiniCPMV"),
  87. "PaliGemmaForConditionalGeneration": ("paligemma",
  88. "PaliGemmaForConditionalGeneration"),
  89. "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
  90. "UltravoxModel": ("ultravox", "UltravoxModel"),
  91. "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
  92. "Qwen2VLForConditionalGeneration": ("qwen2_vl",
  93. "Qwen2VLForConditionalGeneration"),
  94. "PixtralForConditionalGeneration": ("pixtral",
  95. "PixtralForConditionalGeneration"),
  96. "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
  97. }
  98. _CONDITIONAL_GENERATION_MODELS = {
  99. "BartModel": ("bart", "BartForConditionalGeneration"),
  100. "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
  101. }
  102. _MODELS = {
  103. **_GENERATION_MODELS,
  104. **_EMBEDDING_MODELS,
  105. **_MULTIMODAL_MODELS,
  106. **_CONDITIONAL_GENERATION_MODELS,
  107. }
  108. # Architecture -> type.
  109. # out of tree models
  110. _OOT_MODELS: Dict[str, Type[nn.Module]] = {}
  111. # Models not supported by ROCm.
  112. _ROCM_UNSUPPORTED_MODELS = []
  113. # Models partially supported by ROCm.
  114. # Architecture -> Reason.
  115. _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
  116. "Triton flash attention. For half-precision SWA support, "
  117. "please use CK flash attention by setting "
  118. "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
  119. _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
  120. "Qwen2ForCausalLM":
  121. _ROCM_SWA_REASON,
  122. "MistralForCausalLM":
  123. _ROCM_SWA_REASON,
  124. "MixtralForCausalLM":
  125. _ROCM_SWA_REASON,
  126. "PaliGemmaForConditionalGeneration":
  127. ("ROCm flash attention does not yet "
  128. "fully support 32-bit precision on PaliGemma"),
  129. "Phi3VForCausalLM":
  130. ("ROCm Triton flash attention may run into compilation errors due to "
  131. "excessive use of shared memory. If this happens, disable Triton FA "
  132. "by setting `APHRODITE_USE_TRITON_FLASH_ATTN=0`")
  133. }
  134. class ModelRegistry:
  135. @staticmethod
  136. @functools.lru_cache(maxsize=128)
  137. def _get_model(model_arch: str):
  138. module_name, model_cls_name = _MODELS[model_arch]
  139. module = importlib.import_module(
  140. f"aphrodite.modeling.models.{module_name}")
  141. return getattr(module, model_cls_name, None)
  142. @staticmethod
  143. def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
  144. if model_arch in _OOT_MODELS:
  145. return _OOT_MODELS[model_arch]
  146. if model_arch not in _MODELS:
  147. return None
  148. if is_hip():
  149. if model_arch in _ROCM_UNSUPPORTED_MODELS:
  150. raise ValueError(
  151. f"Model architecture {model_arch} is not supported by "
  152. "ROCm for now.")
  153. if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
  154. logger.warning(
  155. f"Model architecture {model_arch} is partially "
  156. "supported by ROCm: "
  157. f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]}")
  158. return ModelRegistry._get_model(model_arch)
  159. @staticmethod
  160. def resolve_model_cls(
  161. architectures: List[str]) -> Tuple[Type[nn.Module], str]:
  162. for arch in architectures:
  163. model_cls = ModelRegistry._try_load_model_cls(arch)
  164. if model_cls is not None:
  165. return (model_cls, arch)
  166. raise ValueError(
  167. f"Model architectures {architectures} are not supported for now. "
  168. f"Supported architectures: {ModelRegistry.get_supported_archs()}")
  169. @staticmethod
  170. def get_supported_archs() -> List[str]:
  171. return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
  172. @staticmethod
  173. def register_model(model_arch: str, model_cls: Type[nn.Module]):
  174. if model_arch in _MODELS:
  175. logger.warning(f"Model architecture {model_arch} is already "
  176. "registered, and will be overwritten by the new "
  177. f"model class {model_cls.__name__}.")
  178. global _OOT_MODELS
  179. _OOT_MODELS[model_arch] = model_cls
  180. @staticmethod
  181. def is_embedding_model(model_arch: str) -> bool:
  182. return model_arch in _EMBEDDING_MODELS
  183. @staticmethod
  184. def is_multimodal_model(model_arch: str) -> bool:
  185. # TODO: find a way to avoid initializing CUDA prematurely to
  186. # use `supports_multimodal` to determine if a model is multimodal
  187. # model_cls = ModelRegistry._try_load_model_cls(model_arch)
  188. # from aphrodite.modeling.models.interfaces import supports_multimodal
  189. return model_arch in _MULTIMODAL_MODELS
  190. __all__ = [
  191. "ModelRegistry",
  192. ]