123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- import functools
- import importlib
- from typing import Dict, List, Optional, Tuple, Type
- import torch.nn as nn
- from loguru import logger
- from aphrodite.common.utils import is_hip
- # Architecture -> (module, class).
- _GENERATION_MODELS = {
- "AquilaModel": ("llama", "LlamaForCausalLM"),
- "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
- "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
- "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
- "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
- "Blip2ForConditionalGeneration":
- ("blip2", "Blip2ForConditionalGeneration"),
- "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
- "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
- "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
- "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
- "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
- "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
- "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
- "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
- "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
- "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
- "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
- "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
- "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
- "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
- "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
- "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
- "InternVLChatModel": ("internvl", "InternVLChatModel"),
- "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
- "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
- "LlavaForConditionalGeneration":
- ("llava", "LlavaForConditionalGeneration"),
- "LlavaNextForConditionalGeneration":
- ("llava_next", "LlavaNextForConditionalGeneration"),
- # For decapoda-research/llama-*
- "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
- "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
- "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
- "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
- # transformers's mpt class has lower case
- "MptForCausalLM": ("mpt", "MPTForCausalLM"),
- "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
- "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
- "MiniCPMV": ("minicpmv", "MiniCPMV"),
- "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
- "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
- "OPTForCausalLM": ("opt", "OPTForCausalLM"),
- "OrionForCausalLM": ("orion", "OrionForCausalLM"),
- "PaliGemmaForConditionalGeneration":
- ("paligemma", "PaliGemmaForConditionalGeneration"),
- "PhiForCausalLM": ("phi", "PhiForCausalLM"),
- "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
- "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
- "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
- "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
- "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
- "RWForCausalLM": ("falcon", "FalconForCausalLM"),
- "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
- "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
- "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
- "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
- "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
- "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
- "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
- "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
- "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
- "MedusaModel": ("medusa", "Medusa"),
- "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
- "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
- "ChameleonForConditionalGeneration": ("chameleon",
- "ChameleonForConditionalGeneration"),
- }
- _EMBEDDING_MODELS = {
- "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
- }
- _CONDITIONAL_GENERATION_MODELS = {
- "BartModel": ("bart", "BartForConditionalGeneration"),
- "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
- }
# Combined lookup table for every built-in architecture, in the order
# generation -> embedding -> conditional generation. If a key appeared in more
# than one table, the later table's value would win while the key keeps its
# first-seen position.
_MODELS: Dict[str, Tuple[str, str]] = _GENERATION_MODELS.copy()
_MODELS.update(_EMBEDDING_MODELS)
_MODELS.update(_CONDITIONAL_GENERATION_MODELS)
- # Architecture -> type.
- # out of tree models
- _OOT_MODELS: Dict[str, Type[nn.Module]] = {}
- # Models not supported by ROCm.
- _ROCM_UNSUPPORTED_MODELS = []
- # Models partially supported by ROCm.
- # Architecture -> Reason.
- _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
- "Triton flash attention. For half-precision SWA support, "
- "please use CK flash attention by setting "
- "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
- _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
- "Qwen2ForCausalLM":
- _ROCM_SWA_REASON,
- "MistralForCausalLM":
- _ROCM_SWA_REASON,
- "MixtralForCausalLM":
- _ROCM_SWA_REASON,
- "PaliGemmaForConditionalGeneration":
- ("ROCm flash attention does not yet "
- "fully support 32-bit precision on PaliGemma"),
- "Phi3VForCausalLM":
- ("ROCm Triton flash attention may run into compilation errors due to "
- "excessive use of shared memory. If this happens, disable Triton FA "
- "by setting `APHRODITE_USE_TRITON_FLASH_ATTN=0`")
- }
class ModelRegistry:
    """Map architecture names to model classes.

    Built-in architectures are listed in ``_MODELS`` and their classes are
    imported lazily from ``aphrodite.modeling.models.<module>``. Classes
    added at runtime with :meth:`register_model` are stored in
    ``_OOT_MODELS`` and take precedence over a built-in entry with the same
    name.
    """

    @staticmethod
    @functools.lru_cache(maxsize=128)
    def _get_model(model_arch: str):
        """Import and return the built-in class registered for *model_arch*.

        Results are memoized with an LRU cache. Returns ``None`` if the
        target module does not define the expected class name.

        Raises:
            KeyError: if *model_arch* is not a key of ``_MODELS``.
            ImportError: if the target module cannot be imported.
        """
        module_name, model_cls_name = _MODELS[model_arch]
        module = importlib.import_module(
            f"aphrodite.modeling.models.{module_name}")
        return getattr(module, model_cls_name, None)

    @staticmethod
    def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Return the class for *model_arch*, or ``None`` if it is unknown.

        Out-of-tree registrations are checked first and returned as-is. For
        built-in architectures, when ``is_hip()`` is true, unsupported ones
        raise ``ValueError`` and partially supported ones log a warning
        before being loaded.
        """
        if model_arch in _OOT_MODELS:
            return _OOT_MODELS[model_arch]
        if model_arch not in _MODELS:
            return None
        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    f"Model architecture {model_arch} is partially "
                    "supported by ROCm: "
                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]}")
        return ModelRegistry._get_model(model_arch)

    @staticmethod
    def resolve_model_cls(
            architectures: List[str]) -> Tuple[Type[nn.Module], str]:
        """Return ``(model_cls, arch)`` for the first loadable architecture.

        *architectures* is scanned in order and the first entry that
        resolves to a class wins.

        Raises:
            ValueError: if no entry resolves; also propagated from
                :meth:`_try_load_model_cls` when an entry reached during the
                scan is unsupported on ROCm.
        """
        for arch in architectures:
            model_cls = ModelRegistry._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)
        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {ModelRegistry.get_supported_archs()}")

    @staticmethod
    def get_supported_archs() -> List[str]:
        """Return every architecture name the registry can resolve.

        Built-in names come first, in table order, followed by out-of-tree
        registrations that do not shadow a built-in name.
        """
        # Include runtime registrations: they resolve through
        # _try_load_model_cls, so they must also appear in the list shown to
        # users (e.g. in resolve_model_cls's error message).
        return list(_MODELS) + [
            arch for arch in _OOT_MODELS if arch not in _MODELS
        ]

    @staticmethod
    def register_model(model_arch: str, model_cls: Type[nn.Module]):
        """Register *model_cls* as the implementation of *model_arch*.

        The new class takes precedence over any built-in or previously
        registered class for the same name. A warning is logged whenever an
        existing entry is replaced.
        """
        # Warn on re-registering an out-of-tree name as well, not only when
        # shadowing a built-in architecture.
        if model_arch in _MODELS or model_arch in _OOT_MODELS:
            logger.warning(f"Model architecture {model_arch} is already "
                           "registered, and will be overwritten by the new "
                           f"model class {model_cls.__name__}.")
        # Item assignment mutates the module-level dict in place, so no
        # ``global`` declaration is needed.
        _OOT_MODELS[model_arch] = model_cls

    @staticmethod
    def is_embedding_model(model_arch: str) -> bool:
        """Return True if *model_arch* is listed in ``_EMBEDDING_MODELS``.

        Out-of-tree registrations are not consulted.
        """
        return model_arch in _EMBEDDING_MODELS
# Public API of this module: only the registry class is exported.
__all__ = ["ModelRegistry"]
|