__init__.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import importlib
  2. from typing import List, Optional, Type
  3. import torch.nn as nn
  4. from aphrodite.common.logger import init_logger
  5. from aphrodite.common.utils import is_hip
  # Module logger obtained from aphrodite's `init_logger` helper, keyed by
  # this module's import path.
  logger = init_logger(__name__)
  # Architecture name -> (module name under `aphrodite.modeling.models`,
  # class name inside that module). Modules are only imported when an
  # architecture is actually requested, so this table stays cheap to load.
  _MODELS = {
      "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
      "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
      "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
      "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
      "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
      # Alternate capitalisation of the architecture name; it maps to the
      # same implementation as "LlamaForCausalLM".
      "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
      "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
      "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
      "PhiForCausalLM": ("phi", "PhiForCausalLM"),
      "YiForCausalLM": ("yi", "YiForCausalLM"),
  }

  # Architectures that must not be loaded on ROCm (loading them raises).
  _ROCM_UNSUPPORTED_MODELS = []

  # Shared reason for architectures that rely on sliding window attention.
  # Kept in one place so the per-model messages cannot drift apart (they
  # previously differed in "ROCM"/"ROCm" spelling and trailing punctuation).
  _SLIDING_WINDOW_ROCM_REASON = (
      "Sliding window attention is not yet supported in ROCm's flash "
      "attention.")

  # Architectures only partially supported by ROCm.
  # Architecture -> human-readable reason, included in the load-time warning.
  _ROCM_PARTIALLY_SUPPORTED_MODELS = {
      "MistralForCausalLM": _SLIDING_WINDOW_ROCM_REASON,
      "MixtralForCausalLM": _SLIDING_WINDOW_ROCM_REASON,
  }
  class ModelRegistry:
      """Resolves architecture names to model classes via the `_MODELS` table."""

      @staticmethod
      def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
          """Import and return the model class registered for `model_arch`.

          The module holding the class (under `aphrodite.modeling.models`)
          is imported only when this method is called for one of its
          architectures.

          Args:
              model_arch: Architecture name to resolve.

          Returns:
              The registered class, or None when `model_arch` is not in the
              registry or the imported module lacks the registered attribute.

          Raises:
              ValueError: When running on ROCm (`is_hip()` is true) and the
                  architecture is listed as unsupported there.
          """
          entry = _MODELS.get(model_arch)
          if entry is None:
              return None

          # Evaluate the platform check once and reuse it for both lists.
          on_rocm = is_hip()
          if on_rocm and model_arch in _ROCM_UNSUPPORTED_MODELS:
              raise ValueError(f"Model architecture {model_arch} is not "
                               "supported in ROCm for now.")
          if on_rocm and model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
              reason = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
              logger.warning(
                  f"Model architecture {model_arch} is partially supported "
                  "by ROCm: " + reason)

          module_name, class_name = entry
          package = "aphrodite.modeling.models"
          module = importlib.import_module(f"{package}.{module_name}")
          return getattr(module, class_name, None)

      @staticmethod
      def get_supported_archs() -> List[str]:
          """Return every registered architecture name, in table order."""
          return list(_MODELS)
  # Only the registry is exported by `from ... import *`; the lookup tables
  # and logger stay private to this module.
  __all__ = ["ModelRegistry"]