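"""Utilities for selecting and loading models."""
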
import contextlib
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import ModelConfig
from aphrodite.modeling.models import (LlamaForCausalLM, GPTJForCausalLM,
                                       GPTNeoXForCausalLM, MistralForCausalLM)
from aphrodite.modeling.hf_downloader import (initialize_dummy_weights,
                                              get_quant_config)
from aphrodite.modeling.layers.quantized_linear.utils import quant_post_init

_MODEL_REGISTRY = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "MistralForCausalLM": MistralForCausalLM,
}

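# Model classes whose constructors accept a quantization config (see
# get_model below), keyed by quantization method.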
_MODEL_CLASSES_SUPPORT_QUANTIZATION = {
    "awq": [LlamaForCausalLM, MistralForCausalLM],
    "gptq": [
        LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM,
        MistralForCausalLM
    ],
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default dtype even if the body raises.
        torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not currently supported. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


def get_model(model_config: ModelConfig, max_tokens: int) -> nn.Module:
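    """Build the model described by `model_config`, load its weights onto the
    GPU, and return it in eval mode."""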
    model_class = _get_model_architecture(model_config.hf_config)

    # Get the quantization config.
    quant_config = None
    if model_config.quantization is not None:
        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION[
                model_config.quantization]:
            raise ValueError(
                f"Quantization is not supported for {model_class}.")
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.hf_config,
                                        model_config.download_dir)
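        # The quantized kernels require a minimum CUDA compute capability,
        # encoded as major * 10 + minor (e.g. 8.6 -> 86).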
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        if model_config.quantization is not None and (
                model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION[
                    model_config.quantization]):
            model = model_class(model_config.hf_config, quant_config)
        else:
            model = model_class(model_config.hf_config)
- if model_config.load_format == "dummy":
- model = model.cuda()
- # NOTE(woosuk): For accurate performance evaluation, we assign
- # random values to the weights.
- initialize_dummy_weights(model)
- else:
- # Load the weights from the cached or downloaded files.
- model.load_weights(model_config.model, model_config.download_dir,
- model_config.load_format, model_config.revision)
- model = model.cuda()
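        # Give the quantization backend a chance to finish any post-load setup
        # that depends on the maximum number of tokens per forward pass.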
        if model_config.quantization is not None:
            quant_post_init(model, max_tokens)
    return model.eval()
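
# Illustrative usage (placeholder values; assumes a fully populated ModelConfig
# and a CUDA-capable device):
#
#     model_config = ModelConfig(...)
#     model = get_model(model_config, max_tokens=4096)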