import functools
import importlib
from typing import Dict, List, Optional, Type

import torch.nn as nn
from loguru import logger

from aphrodite.common.utils import is_hip

# Architecture -> (module, class).
_GENERATION_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "LlavaForConditionalGeneration":
    ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration":
    ("llava_next", "LlavaNextForConditionalGeneration"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers' MPT class name is lower-cased.
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PaliGemmaForConditionalGeneration":
    ("paligemma", "PaliGemmaForConditionalGeneration"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "MedusaModel": ("medusa", "Medusa"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
}

_EMBEDDING_MODELS = {
    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
}

_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
# Architecture -> model class.
# Out-of-tree models, registered at runtime via ModelRegistry.register_model.
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `APHRODITE_USE_TRITON_FLASH_ATTN=0`"),
}


class ModelRegistry:
    """Registry that resolves model architecture names to model classes."""

    @staticmethod
    @functools.lru_cache(maxsize=128)
    def _get_model(model_arch: str):
        # Lazily import the module that defines the architecture; the result
        # is memoized so each model module is imported at most once.
        module_name, model_cls_name = _MODELS[model_arch]
        module = importlib.import_module(
            f"aphrodite.modeling.models.{module_name}")
        return getattr(module, model_cls_name, None)

    @staticmethod
    def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """Return the model class for `model_arch`, or None if unknown."""
        # Out-of-tree registrations take precedence over built-in models.
        if model_arch in _OOT_MODELS:
            return _OOT_MODELS[model_arch]
        if model_arch not in _MODELS:
            return None
        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    f"Model architecture {model_arch} is partially "
                    "supported by ROCm: "
                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]}")
        return ModelRegistry._get_model(model_arch)

    @staticmethod
    def get_supported_archs() -> List[str]:
        """List all built-in architecture names."""
        return list(_MODELS.keys())

    @staticmethod
    def register_model(model_arch: str, model_cls: Type[nn.Module]):
        """Register an out-of-tree model class under `model_arch`."""
        if model_arch in _MODELS:
            logger.warning(f"Model architecture {model_arch} is already "
                           "registered, and will be overwritten by the new "
                           f"model class {model_cls.__name__}.")
        global _OOT_MODELS
        _OOT_MODELS[model_arch] = model_cls

    @staticmethod
    def is_embedding_model(model_arch: str) -> bool:
        """Return True if `model_arch` is an embedding (non-generative) model."""
        return model_arch in _EMBEDDING_MODELS

__all__ = [
    "ModelRegistry",
]
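
# Usage sketch (illustrative only; kept as a comment so importing this module
# has no side effects). "MyModelForCausalLM" is a hypothetical out-of-tree
# architecture used for demonstration, not part of this registry:
#
#     import torch.nn as nn
#     from aphrodite.modeling.models import ModelRegistry
#
#     class MyModelForCausalLM(nn.Module):  # hypothetical model class
#         ...
#
#     # Register the out-of-tree architecture, then resolve it by name.
#     ModelRegistry.register_model("MyModelForCausalLM", MyModelForCausalLM)
#     assert (ModelRegistry.load_model_cls("MyModelForCausalLM")
#             is MyModelForCausalLM)
#
#     # Built-in architectures are imported lazily on first lookup.
#     llama_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")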