# utils.py (1.1 KB)
  1. """Utilities for selecting and loading models."""
  2. import contextlib
  3. from typing import Tuple, Type
  4. import torch
  5. from torch import nn
  6. from aphrodite.common.config import ModelConfig
  7. from aphrodite.modeling.models import ModelRegistry
  8. @contextlib.contextmanager
  9. def set_default_torch_dtype(dtype: torch.dtype):
  10. """Sets the default torch dtype to the given dtype."""
  11. old_dtype = torch.get_default_dtype()
  12. torch.set_default_dtype(dtype)
  13. yield
  14. torch.set_default_dtype(old_dtype)
  15. def get_model_architecture(
  16. model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
  17. architectures = getattr(model_config.hf_config, "architectures", [])
  18. # Special handling for quantized Mixtral.
  19. # FIXME: This is a temporary hack.
  20. if (model_config.quantization is not None
  21. and model_config.quantization != "fp8"
  22. and "MixtralForCausalLM" in architectures):
  23. architectures = ["QuantMixtralForCausalLM"]
  24. return ModelRegistry.resolve_model_cls(architectures)
  25. def get_architecture_class_name(model_config: ModelConfig) -> str:
  26. return get_model_architecture(model_config)[1]