utils.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. """Utilities for selecting and loading models."""
  2. import contextlib
  3. from typing import Tuple, Type
  4. import torch
  5. from torch import nn
  6. from aphrodite.common.config import ModelConfig
  7. from aphrodite.modeling.models import ModelRegistry
  8. @contextlib.contextmanager
  9. def set_default_torch_dtype(dtype: torch.dtype):
  10. """Sets the default torch dtype to the given dtype."""
  11. old_dtype = torch.get_default_dtype()
  12. torch.set_default_dtype(dtype)
  13. yield
  14. torch.set_default_dtype(old_dtype)
  15. def get_model_architecture(
  16. model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
  17. architectures = getattr(model_config.hf_config, "architectures", [])
  18. # Special handling for quantized Mixtral.
  19. # FIXME: This is a temporary hack.
  20. if (model_config.quantization is not None
  21. and model_config.quantization != "fp8"
  22. and "MixtralForCausalLM" in architectures):
  23. architectures = ["QuantMixtralForCausalLM"]
  24. for arch in architectures:
  25. model_cls = ModelRegistry.load_model_cls(arch)
  26. if model_cls is not None:
  27. return (model_cls, arch)
  28. raise ValueError(
  29. f"Model architectures {architectures} are not supported for now. "
  30. f"Supported architectures: {ModelRegistry.get_supported_archs()}")
  31. def get_architecture_class_name(model_config: ModelConfig) -> str:
  32. return get_model_architecture(model_config)[1]