# ruff: noqa: SIM117
import copy
import gc
import glob
import os
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional,
                    Tuple, Type)

import torch
from torch import nn

from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, DeviceConfig,
                                     LoadConfig, LoadFormat, LoRAConfig,
                                     ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, is_aphrodite_serialized_tensorizer, load_with_tensorizer,
    tensorizer_weights_iterator)
from aphrodite.modeling.model_loader.utils import (get_model_architecture,
                                                   set_default_torch_dtype)
from aphrodite.modeling.model_loader.weight_utils import (
    download_weights_from_hf, filter_files_not_needed_for_inference,
    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
from aphrodite.quantization.bitsandbytes import (BNBLinearMethod,
                                                 replace_quant_params)

if TYPE_CHECKING:
    from aphrodite.modeling.layers.linear import LinearMethodBase

_VISION_MODEL_CLASSES = [
    LlavaForConditionalGeneration,
]


def _get_linear_method(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional["LinearMethodBase"]:
    """Get the (maybe quantized) linear method."""
    linear_method = None
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config, load_config)
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        linear_method = quant_config.get_linear_method()
    return linear_method
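
# A worked example of the capability check above (a sketch; the device values
# are illustrative, not taken from the source):
#
#     capability = torch.cuda.get_device_capability()  # e.g. (8, 6)
#     capability = capability[0] * 10 + capability[1]  # -> 86
#
# A quantization method whose get_min_capability() returns 80 would then be
# accepted on that GPU, while a (7, 0) device (capability 70) would raise the
# ValueError above.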


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]:
    """Get extra kwargs for model initialization."""
    extra_kwargs = {}
    if hasattr(model_class, "supported_lora_modules"):
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on GitHub.")
    elif model_class in _VISION_MODEL_CLASSES:
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs


def _initialize_model(
        model_config: ModelConfig, load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
    """Initialize a model with the given configurations."""
    model_class = get_model_architecture(model_config)[0]
    linear_method = _get_linear_method(model_config, load_config)
    return model_class(config=model_config.hf_config,
                       linear_method=linear_method,
                       **_get_model_initialization_kwargs(
                           model_class, lora_config, vision_language_config))
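
# A minimal sketch of how the helpers above compose (the config objects are
# assumed to come from the engine's configuration plumbing, not built here):
#
#     model = _initialize_model(model_config, load_config,
#                               lora_config=None,
#                               vision_language_config=None)
#
# This yields an nn.Module with freshly constructed (untrained) parameters;
# one of the loaders below is responsible for filling in real weights.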


class BaseModelLoader(ABC):
    """Base class for model loaders."""

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        """Load a model with the given configurations."""
        ...


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported "
                             f"for load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
        True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if APHRODITE_USE_MODELSCOPE:
            # Download the model from the ModelScope hub. Lazy import so that
            # modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    revision=revision)
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(model_name_or_path,
                                                 self.load_config.download_dir,
                                                 allow_patterns, revision)
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if not use_safetensors:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors
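
    # An illustration of the pattern resolution above (a sketch, not an
    # exhaustive trace): with load_format=LoadFormat.AUTO and
    # fall_back_to_pt=True, allow_patterns is
    # ["*.safetensors", "*.bin", "*.pt"]. The glob loop stops at the first
    # pattern that matches anything in hf_folder, so a checkpoint shipping
    # model-00001-of-00002.safetensors is loaded via safetensors and the
    # *.bin/*.pt patterns are never consulted.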

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str],
        fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)

        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only supports *.bin checkpoints.
            assert use_safetensors is False
            return np_cache_weights_iterator(model_name_or_path,
                                             self.load_config.download_dir,
                                             hf_folder, hf_weights_files)
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        return pt_weights_iterator(hf_weights_files)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            linear_method = _get_linear_method(model_config, self.load_config)

            context = torch.device(device_config.device) if not (
                isinstance(linear_method, BNBLinearMethod)
                and linear_method.quant_config.from_float) else nullcontext()

            with context:
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config)
            model.load_weights(
                self._get_weights_iterator(model_config.model,
                                           model_config.revision,
                                           fall_back_to_pt=getattr(
                                               model,
                                               "fall_back_to_pt_during_load",
                                               True)))
            for _, module in model.named_modules():
                # Use a separate name here so the top-level `linear_method`
                # (needed for the BNB check below) is not clobbered by the
                # per-module attribute lookup.
                module_linear_method = getattr(module, "linear_method", None)
                if module_linear_method is not None:
                    module_linear_method.process_weights_after_loading(module)
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()

            if isinstance(linear_method, BNBLinearMethod):
                replace_quant_params(
                    model,
                    quant_config=linear_method.quant_config,
                    modules_to_not_convert="lm_head",
                )
                torch.cuda.synchronize()
                if linear_method.quant_config.from_float:
                    model = model.cuda()
                gc.collect()
                torch.cuda.empty_cache()
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported "
                             f"for load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config)
                # NOTE(woosuk): For accurate performance evaluation, we assign
                # random values to the weights.
                initialize_dummy_weights(model)
        return model.eval()
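
# Dummy loading is useful when only the memory footprint or runtime shape of
# a model matters, since random weights occupy the same space as trained
# ones. A hedged usage sketch (assumes LoadConfig accepts load_format as a
# keyword argument, as the dispatch in get_model_loader below implies):
#
#     loader = get_model_loader(LoadConfig(load_format=LoadFormat.DUMMY))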


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_unserialized(
            self, model_config: ModelConfig, device_config: DeviceConfig,
            lora_config: Optional[LoRAConfig],
            vision_language_config: Optional[VisionLanguageConfig]
    ) -> nn.Module:
        """Load an unserialized model with tensorizer.

        Unserialized here means "not serialized with tensorizer". This
        should still be faster than default HuggingFace loading, but will
        be slower than loading a tensorizer-serialized model.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config)
                model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
            self, model_config: ModelConfig, device_config: DeviceConfig,
            lora_config: Optional[LoRAConfig],
            vision_language_config: Optional[VisionLanguageConfig]
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        See the examples/tensorize_aphrodite_model.py example script for
        serializing Aphrodite models."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                linear_method = _get_linear_method(model_config,
                                                   self.load_config)
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, vision_language_config)
                extra_kwargs["linear_method"] = linear_method

                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
        return model.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        self._verify_config(model_config, parallel_config)

        if is_aphrodite_serialized_tensorizer(self.tensorizer_config):
            return self._load_model_serialized(model_config, device_config,
                                               lora_config,
                                               vision_language_config)
        return self._load_model_unserialized(model_config, device_config,
                                             lora_config,
                                             vision_language_config)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    return DefaultModelLoader(load_config)
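
# Example usage (a minimal sketch; the LoadConfig and the other config
# objects are assumed to be produced by the engine's argument parsing, not
# constructed by hand like this):
#
#     loader = get_model_loader(LoadConfig(load_format=LoadFormat.AUTO))
#     assert isinstance(loader, DefaultModelLoader)
#     model = loader.load_model(model_config=model_config,
#                               device_config=device_config,
#                               lora_config=None,
#                               vision_language_config=None,
#                               parallel_config=parallel_config,
#                               scheduler_config=scheduler_config)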