__init__.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import importlib
  2. from typing import Dict, List, Optional, Type
  3. import torch.nn as nn
  4. from loguru import logger
  5. from aphrodite.common.utils import is_hip
  6. # Architecture -> (module, class).
  7. _GENERATION_MODELS = {
  8. "AquilaModel": ("llama", "LlamaForCausalLM"),
  9. "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
  10. "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
  11. "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
  12. "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
  13. "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
  14. "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
  15. "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
  16. "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
  17. "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
  18. "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
  19. "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
  20. "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
  21. "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
  22. "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  23. "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
  24. "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
  25. "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
  26. "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
  27. "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
  28. "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
  29. "LlavaForConditionalGeneration":
  30. ("llava", "LlavaForConditionalGeneration"),
  31. # For decapoda-research/llama-*
  32. "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
  33. "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
  34. "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
  35. "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
  36. # transformers's mpt class has lower case
  37. "MptForCausalLM": ("mpt", "MPTForCausalLM"),
  38. "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
  39. "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
  40. "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
  41. "OPTForCausalLM": ("opt", "OPTForCausalLM"),
  42. "OrionForCausalLM": ("orion", "OrionForCausalLM"),
  43. "PhiForCausalLM": ("phi", "PhiForCausalLM"),
  44. "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
  45. "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
  46. "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
  47. "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
  48. "RWForCausalLM": ("falcon", "FalconForCausalLM"),
  49. "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
  50. "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
  51. "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
  52. "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
  53. "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
  54. }
  55. _EMBEDDING_MODELS = {
  56. "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
  57. }
  58. _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
  59. # Architecture -> type.
  60. # out of tree models
  61. _OOT_MODELS: Dict[str, Type[nn.Module]] = {}
  62. # Models not supported by ROCm.
  63. _ROCM_UNSUPPORTED_MODELS = []
  64. # Models partially supported by ROCm.
  65. # Architecture -> Reason.
  66. _ROCM_PARTIALLY_SUPPORTED_MODELS = {
  67. "Qwen2ForCausalLM":
  68. "Sliding window attention is not yet supported in ROCm's flash attention",
  69. "MistralForCausalLM":
  70. "Sliding window attention is not yet supported in ROCm's flash attention",
  71. "MixtralForCausalLM":
  72. "Sliding window attention is not yet supported in ROCm's flash attention",
  73. }
  74. class ModelRegistry:
  75. @staticmethod
  76. def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
  77. if model_arch in _OOT_MODELS:
  78. return _OOT_MODELS[model_arch]
  79. if model_arch not in _MODELS:
  80. return None
  81. if is_hip():
  82. if model_arch in _ROCM_UNSUPPORTED_MODELS:
  83. raise ValueError(
  84. f"Model architecture {model_arch} is not supported by "
  85. "ROCm for now.")
  86. if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
  87. logger.warning(
  88. f"Model architecture {model_arch} is partially "
  89. "supported by ROCm: "
  90. f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]}")
  91. module_name, model_cls_name = _MODELS[model_arch]
  92. module = importlib.import_module(
  93. f"aphrodite.modeling.models.{module_name}")
  94. return getattr(module, model_cls_name, None)
  95. @staticmethod
  96. def get_supported_archs() -> List[str]:
  97. return list(_MODELS.keys())
  98. @staticmethod
  99. def register_model(model_arch: str, model_cls: Type[nn.Module]):
  100. if model_arch in _MODELS:
  101. logger.warning(f"Model architecture {model_arch} is already "
  102. "registered, and will be overwritten by the new "
  103. f"model class {model_cls.__name__}.")
  104. global _OOT_MODELS
  105. _OOT_MODELS[model_arch] = model_cls
  106. @staticmethod
  107. def is_embedding_model(model_arch: str) -> bool:
  108. return model_arch in _EMBEDDING_MODELS
  109. __all__ = [
  110. "ModelRegistry",
  111. ]