# ruff: noqa: SIM117
import collections
import copy
import glob
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Type

import huggingface_hub
import torch
from loguru import logger
from torch import nn

from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                     DeviceConfig, LoadConfig, LoadFormat,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
    tensorizer_weights_iterator)
from aphrodite.modeling.model_loader.utils import (get_model_architecture,
                                                   set_default_torch_dtype)
from aphrodite.modeling.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
from aphrodite.modeling.models.vlm_base import VisionLanguageModelBase
from aphrodite.quantization.base_config import QuantizationConfig


def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Get the quantization config."""
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config, load_config)
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
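        # e.g. an A100 reports compute capability (8, 0), which is encoded
        # here as 80 and compared against the quant method's minimum below.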
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        return quant_config
    return None


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]:
    """Get extra kwargs for model initialization."""
    extra_kwargs = {}
    if hasattr(model_class, "supported_lora_modules"):
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on GitHub.")
    elif issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("Provide `image_input_type` and other vision "
                             "related configurations through LLM entrypoint "
                             "or engine arguments.")
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs


def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
                      lora_config: Optional[LoRAConfig],
                      vision_language_config: Optional[VisionLanguageConfig],
                      cache_config: CacheConfig) -> nn.Module:
    """Initialize a model with the given configurations."""
    model_class = get_model_architecture(model_config)[0]
    quant_config = _get_quantization_config(model_config, load_config)
    return model_class(config=model_config.hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **_get_model_initialization_kwargs(
                           model_class, lora_config, vision_language_config))
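

# Illustration (hedged): for a config whose architecture resolves to, say,
# LlamaForCausalLM, _initialize_model effectively expands to
#     LlamaForCausalLM(config=hf_config, cache_config=cache_config,
#                      quant_config=quant_config)
# plus lora_config / vision_language_config kwargs when the class supports
# them. The concrete class name above is only an example.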


class BaseModelLoader(ABC):
    """Base class for model loaders."""

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model with the given configurations."""
        ...


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
        True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if APHRODITE_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.
        If the model is not local, it will be downloaded."""
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")
        if fall_back_to_pt:
            allow_patterns += ["*.pt"]
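        # With LoadFormat.AUTO and fall_back_to_pt=True, the effective search
        # order is therefore *.safetensors, *.bin, *.pt; the first pattern
        # that matches any files wins, since the glob loop below breaks on
        # the first non-empty result.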
        if not is_local:
            hf_folder = download_weights_from_hf(model_name_or_path,
                                                 self.load_config.download_dir,
                                                 allow_patterns, revision)
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break
        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3, there are both sharded
            # safetensors files and a consolidated safetensors file; loading
            # both breaks the model. Here, we download the
            # `model.safetensors.index.json` and filter out any files not
            # found in the index.
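            # For example, such a repo may ship model-00001-of-00003.safetensors
            # (plus the index file) alongside consolidated.safetensors; only
            # the indexed shards are kept. (File names here are illustrative.)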
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir,
                    revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, model_name_or_path: str, revision: Optional[str],
            fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only supports *.bin checkpoints
            assert use_safetensors is False
            return np_cache_weights_iterator(model_name_or_path,
                                             self.load_config.download_dir,
                                             hf_folder, hf_weights_files)
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        return pt_weights_iterator(hf_weights_files)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
                model.load_weights(
                    self._get_weights_iterator(model_config.model,
                                               model_config.revision,
                                               fall_back_to_pt=getattr(
                                                   model,
                                                   "fall_back_to_pt_during_load",
                                                   True)), )
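                # A model class may set `fall_back_to_pt_during_load = False`
                # to skip the *.pt fallback; the getattr default above keeps
                # the fallback enabled otherwise.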

                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
                    # FIXME: Remove this after Mixtral is updated
                    # to use quant_method.
                    if hasattr(module, "process_weights_after_loading"):
                        module.process_weights_after_loading()
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
                # NOTE: For accurate performance evaluation, we assign
                # random values to the weights.
                initialize_dummy_weights(model)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)
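
    # A minimal sketch of what the extra config may look like when it is not
    # already a TensorizerConfig instance; the exact field set is defined by
    # TensorizerConfig, and the key below is only assumed for illustration:
    #
    #     load_config.model_loader_extra_config = {
    #         "tensorizer_uri": "s3://bucket/model.tensors",
    #     }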

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't Aphrodite-tensorized (see
        examples/tensorize_aphrodite_model.py). This should still be faster
        than default HuggingFace loading, but will be slower than loading an
        Aphrodite-tensorized model.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
                model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects an Aphrodite-tensorized model. See the
        examples/tensorize_aphrodite_model.py example script
        for serializing Aphrodite models."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, vision_language_config)
                extra_kwargs["quant_config"] = quant_config
                extra_kwargs["cache_config"] = cache_config

                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
        return model.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        self._verify_config(model_config, parallel_config)

        if is_aphrodite_tensorized(self.tensorizer_config):
            return self._load_model_serialized(model_config, device_config,
                                               lora_config,
                                               vision_language_config,
                                               cache_config)
        return self._load_model_serialized_cpu(model_config, device_config,
                                               lora_config,
                                               vision_language_config,
                                               cache_config)


class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each
    worker only needs to read its own shard rather than the entire checkpoint.
    See `examples/save_sharded_states.py` for creating a sharded checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
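    # e.g. DEFAULT_PATTERN renders to "model-rank-0-part-0.safetensors" for
    # the first file written by tensor-parallel rank 0.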

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        same_storage_groups = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(model_name_or_path,
                                            self.load_config.download_dir,
                                            allow_patterns, revision)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        from safetensors.torch import safe_open

        from aphrodite.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning("loading tensor of shape "
                                           f"{tensor.shape} into parameter "
                                           f"'{key}' of shape {param_shape}")
                        param_data.copy_(tensor)
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from safetensors.torch import save_file

        from aphrodite.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
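
    # A minimal usage sketch (within each tensor-parallel worker; `model` and
    # `path` are placeholders, and the 5 GiB cap is only an example):
    #
    #     ShardedStateLoader.save_model(model, path, max_size=5 * 1024**3)
    #
    # Each rank writes one or more part files, starting a new part whenever
    # the next tensor would push the current file past `max_size` bytes; file
    # names follow DEFAULT_PATTERN unless `pattern` is given.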


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    return DefaultModelLoader(load_config)
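

# Illustrative call site (hedged; the engine normally does this internally,
# and the config objects below are assumed to be already constructed):
#
#     loader = get_model_loader(load_config)
#     model = loader.load_model(model_config=model_config,
#                               device_config=device_config,
#                               lora_config=None,
#                               vision_language_config=None,
#                               parallel_config=parallel_config,
#                               scheduler_config=scheduler_config,
#                               cache_config=cache_config)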