import contextlib
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import ModelConfig
from aphrodite.modeling.hf_downloader import (get_quant_config,
                                              initialize_dummy_weights)
from aphrodite.modeling.models import (GPTJForCausalLM, GPTNeoXForCausalLM,
                                       LlamaForCausalLM)

# Maps architecture names from a Hugging Face config's `architectures` field
# to the corresponding aphrodite model implementations.
_MODEL_REGISTRY = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,  # Older LLaMA configs use this name.
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
}

# Model classes that support quantized weights. NOTE: This set holds the
# classes themselves (not their names), because `get_model` checks the class
# returned by `_get_model_architecture` for membership.
_QUANT_REGISTRY = {
    LlamaForCausalLM,
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype, restoring the old
    default on exit (even if the body raises)."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)
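
# A minimal usage sketch of the context manager above: parameters created
# inside the `with` block pick up the requested dtype, and the previous
# default is restored on exit. The final assertion assumes the process-wide
# default is float32, which is torch's out-of-the-box setting.
#
#   with _set_default_torch_dtype(torch.float16):
#       layer = nn.Linear(8, 8)   # layer.weight.dtype == torch.float16
#   assert torch.get_default_dtype() == torch.float32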


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Resolves the model class from the architectures listed in the config."""
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported yet. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def get_model(model_config: ModelConfig) -> nn.Module:
    """Instantiates the model on GPU and loads (or randomizes) its weights."""
    model_class = _get_model_architecture(model_config.hf_config)

    # Resolve and validate the quantization config, if quantization is used.
    quant_config = None
    if model_config.quantization is not None:
        if model_class not in _QUANT_REGISTRY:
            raise ValueError(
                f"Quantization is not supported for {model_class.__name__}.")
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.download_dir)
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. "
                f"Supported dtypes: {supported_dtypes}")

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance. The weights will be initialized as empty
        # tensors and filled in below.
        if model_class in _QUANT_REGISTRY:
            model = model_class(model_config.hf_config, quant_config)
        else:
            model = model_class(model_config.hf_config)
        if model_config.load_format == "dummy":
            model = model.cuda()
            # NOTE: For accurate performance evaluation, we assign random
            # values to the weights instead of loading a real checkpoint.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format)
            model = model.cuda()
    return model.eval()
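

# Usage sketch: `get_model` is driven entirely by a `ModelConfig`. The exact
# `ModelConfig` constructor arguments vary between aphrodite versions, so the
# call below is illustrative only; the model name and dtype are assumptions,
# not values taken from this module.
#
#   config = ModelConfig(...)  # e.g. model="my-org/my-llama-model",
#                              #      dtype=torch.float16, load_format="auto"
#   model = get_model(config)  # CUDA-resident nn.Module in eval mode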