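"""Registry of the model architectures supported by Aphrodite.

Maps architecture names, as they appear in a model's Hugging Face config
(e.g. ``"LlamaForCausalLM"``), to the module and class that implement them
under ``aphrodite.modeling.models``, and records per-backend restrictions
for ROCm and AWS Neuron.

Illustrative usage (a minimal sketch, assuming this module is importable
as ``aphrodite.modeling.models``):

    from aphrodite.modeling.models import ModelRegistry

    model_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")
    if model_cls is None:
        raise ValueError("Unsupported model architecture.")
"""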
import importlib
from typing import List, Optional, Type

from loguru import logger
import torch.nn as nn

from aphrodite.common.utils import is_hip, is_neuron
# Architecture -> (module, class).
_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("cohere", "CohereForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    "YiForCausalLM": ("llama", "LlamaForCausalLM"),
    # transformers' MPT class uses a lowercase name.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "Qwen2ForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
}

_NEURON_SUPPORTED_MODELS = {
    "LlamaForCausalLM": "neuron.llama",
    "MistralForCausalLM": "neuron.mistral",
}


class ModelRegistry:

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
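        """Resolve an architecture name to its implementing class.

        Returns None if the architecture is unknown or the class cannot be
        found in its module. Raises ValueError when the architecture is
        known but unsupported on the current backend (ROCm or AWS Neuron),
        and logs a warning for architectures only partially supported on
        ROCm.
        """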
        if model_arch not in _MODELS:
            return None
        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    f"Model architecture {model_arch} is partially supported "
                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
        elif is_neuron():
            if model_arch not in _NEURON_SUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "AWS Neuron for now.")

        module_name, model_cls_name = _MODELS[model_arch]
        if is_neuron():
            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
        module = importlib.import_module(
            f"aphrodite.modeling.models.{module_name}")
        return getattr(module, model_cls_name, None)

    @staticmethod
    def get_supported_archs() -> List[str]:
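        """Return the architecture names known to the registry."""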
        return list(_MODELS.keys())


__all__ = [
    "ModelRegistry",
]