1234567891011121314151617181920212223242526272829303132333435 |
- """Utilities for selecting and loading models."""
- import contextlib
- from typing import Tuple, Type
- import torch
- from torch import nn
- from aphrodite.common.config import ModelConfig
- from aphrodite.modeling.models import ModelRegistry
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set torch's default dtype for the duration of a ``with``.

    The previous default (read via ``torch.get_default_dtype``) is restored
    when the ``with`` block exits, including when the block raises.

    Args:
        dtype: The dtype to install with ``torch.set_default_dtype`` while
            the context is active.

    Yields:
        None.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore unconditionally so an exception raised inside the ``with``
        # body does not leave the temporary dtype installed process-wide.
        torch.set_default_dtype(old_dtype)
def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class to use for ``model_config``.

    Args:
        model_config: Config whose ``hf_config.architectures`` lists the
            architecture name(s) to look up, and whose ``quantization``
            field names the quantization method (``None`` if unquantized).

    Returns:
        The result of ``ModelRegistry.resolve_model_cls`` for the (possibly
        rewritten) architecture list; per the annotation, a
        ``(model class, architecture name)`` tuple.
    """
    # HF configs may define ``architectures`` but leave it as ``None``; the
    # ``or []`` keeps the membership test below from raising TypeError.
    architectures = (getattr(model_config.hf_config, "architectures", None)
                     or [])
    # Special handling for quantized Mixtral: any quantization method other
    # than fp8 is routed to the dedicated QuantMixtralForCausalLM class.
    # FIXME: This is a temporary hack.
    if (model_config.quantization is not None
            and model_config.quantization != "fp8"
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]
    return ModelRegistry.resolve_model_cls(architectures)
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return the architecture name resolved for ``model_config``.

    This is the second element of the tuple produced by
    ``get_model_architecture``.
    """
    resolved = get_model_architecture(model_config)
    arch_name = resolved[1]
    return arch_name
|