__init__.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import importlib
  2. from typing import List, Optional, Type
  3. from loguru import logger
  4. import torch.nn as nn
  5. from aphrodite.common.utils import is_hip, is_neuron
  6. # Architecture -> (module, class).
  7. _MODELS = {
  8. "AquilaModel": ("llama", "LlamaForCausalLM"),
  9. "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
  10. "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
  11. "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
  12. "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
  13. "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
  14. "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
  15. "CohereForCausalLM": ("cohere", "CohereForCausalLM"),
  16. "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
  17. "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
  18. "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
  19. "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
  20. "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
  21. "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  22. "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
  23. "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
  24. "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
  25. "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
  26. "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
  27. # For decapoda-research/llama-*
  28. "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
  29. "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
  30. "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
  31. "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
  32. "YiForCausalLM": ("llama", "LlamaForCausalLM"),
  33. # transformers's mpt class has lower case
  34. "MptForCausalLM": ("mpt", "MPTForCausalLM"),
  35. "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
  36. "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
  37. "OPTForCausalLM": ("opt", "OPTForCausalLM"),
  38. "PhiForCausalLM": ("phi", "PhiForCausalLM"),
  39. "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
  40. "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
  41. "RWForCausalLM": ("falcon", "FalconForCausalLM"),
  42. "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
  43. "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
  44. }
  45. # Models not supported by ROCm.
  46. _ROCM_UNSUPPORTED_MODELS = []
  47. # Models partially supported by ROCm.
  48. # Architecture -> Reason.
  49. _ROCM_PARTIALLY_SUPPORTED_MODELS = {
  50. "Qwen2ForCausalLM":
  51. "Sliding window attention is not yet supported in ROCm's flash attention",
  52. "MistralForCausalLM":
  53. "Sliding window attention is not yet supported in ROCm's flash attention",
  54. "MixtralForCausalLM":
  55. "Sliding window attention is not yet supported in ROCm's flash attention",
  56. }
  57. _NEURON_SUPPORTED_MODELS = {
  58. "LlamaForCausalLM": "neuron.llama",
  59. "MistralForCausalLM": "neuron.mistral",
  60. }
  61. class ModelRegistry:
  62. @staticmethod
  63. def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
  64. if model_arch not in _MODELS:
  65. return None
  66. if is_hip():
  67. if model_arch in _ROCM_UNSUPPORTED_MODELS:
  68. raise ValueError(
  69. f"Model architecture {model_arch} is not supported by "
  70. "ROCm for now.")
  71. if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
  72. logger.warning(
  73. f"Model architecture {model_arch} is partially supported "
  74. "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
  75. elif is_neuron():
  76. if model_arch not in _NEURON_SUPPORTED_MODELS:
  77. raise ValueError(
  78. f"Model architecture {model_arch} is not supported by "
  79. "AWS Neuron for now.")
  80. module_name, model_cls_name = _MODELS[model_arch]
  81. if is_neuron():
  82. module_name = _NEURON_SUPPORTED_MODELS[model_arch]
  83. module = importlib.import_module(
  84. f"aphrodite.modeling.models.{module_name}")
  85. return getattr(module, model_cls_name, None)
  86. @staticmethod
  87. def get_supported_archs() -> List[str]:
  88. return list(_MODELS.keys())
  89. __all__ = [
  90. "ModelRegistry",
  91. ]