12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055 |
- # ruff: noqa: SIM117
- import collections
- import copy
- import fnmatch
- import glob
- import json
- import math
- import os
- from abc import ABC, abstractmethod
- from contextlib import contextmanager
- from typing import Any, Dict, Generator, List, Optional, Tuple, Type
- import gguf
- import huggingface_hub
- import numpy as np
- import torch
- from huggingface_hub import HfApi, hf_hub_download
- from loguru import logger
- from torch import nn
- from transformers import AutoModelForCausalLM, PretrainedConfig
- from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
- DeviceConfig, LoadConfig, LoadFormat,
- LoRAConfig, ModelConfig, MultiModalConfig,
- ParallelConfig, SchedulerConfig)
- from aphrodite.common.utils import is_pin_memory_available
- from aphrodite.modeling.model_loader.tensorizer import (
- TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
- serialize_aphrodite_model, tensorizer_weights_iterator)
- from aphrodite.modeling.model_loader.utils import (get_model_architecture,
- set_default_torch_dtype)
- from aphrodite.modeling.model_loader.weight_utils import (
- download_safetensors_index_file_from_hf, download_weights_from_hf,
- filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
- get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
- initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
- safetensors_weights_iterator)
- from aphrodite.modeling.models.interfaces import (has_inner_state,
- supports_lora,
- supports_multimodal)
- from aphrodite.modeling.utils import set_weight_attrs
- from aphrodite.platforms import current_platform
- from aphrodite.quantization.base_config import QuantizationConfig
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily place a module's CPU parameters on ``target_device``.

    On entry, every parameter of ``module`` that currently lives on the CPU
    is moved to ``target_device``; parameters on any other device are left
    alone. On exit (normal or through an exception) each parameter that was
    moved is copied back into a freshly allocated CPU tensor, which is pinned
    when ``is_pin_memory_available()`` returns a true value. Parameters that
    were added to the module while the context was active are not touched.

    If ``target_device`` is the CPU itself, nothing is moved.

    Yields:
        The ``module`` object that was passed in.
    """
    if target_device.type == "cpu":
        # Nothing to relocate when the destination is already the CPU.
        yield module
        return

    # Parameter name -> device it lived on before we moved it.
    moved_from: Dict[str, torch.device] = {}
    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Already off the CPU: leave it where it is.
            continue
        moved_from[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module
    finally:
        use_pinned = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            home = moved_from.get(param_name)
            if home is None:
                # Either a new parameter or one that was never moved.
                continue
            if home.type != "cpu":
                param.data = param.data.to(home)
                continue
            current = param.data
            # `torch.empty_like` does not accept `pin_memory`, so the CPU
            # buffer is allocated with `empty_strided` instead.
            restored = torch.empty_strided(size=current.size(),
                                           stride=current.stride(),
                                           dtype=current.dtype,
                                           layout=current.layout,
                                           device="cpu",
                                           pin_memory=use_pinned)
            restored.copy_(current)
            param.data = restored
def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Resolve and validate the quantization config for a model.

    Returns ``None`` when ``model_config.quantization`` is ``None``.
    Otherwise the config is obtained from ``get_quant_config`` and checked
    against the current platform and the model dtype.

    Raises:
        ValueError: If (on non-TPU platforms) the device capability is below
            the minimum required by the quantization method, or if
            ``model_config.dtype`` is not among the supported activation
            dtypes.
    """
    if model_config.quantization is None:
        return None

    quant_config = get_quant_config(model_config, load_config)

    # The compute-capability check is skipped on TPU.
    if not current_platform.is_tpu():
        major_minor = current_platform.get_device_capability()
        # Encode (major, minor) as a single integer, e.g. (8, 0) -> 80.
        capability = major_minor[0] * 10 + major_minor[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} "
                "is not supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")

    return quant_config
def _get_model_initialization_kwargs(
        model_class: Type[nn.Module],
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
    """Collect the optional constructor kwargs that ``model_class`` accepts.

    The result may contain ``lora_config``, ``multimodal_config`` and
    ``scheduler_config``, depending on what ``supports_lora``,
    ``supports_multimodal`` and ``has_inner_state`` report for the class.

    Raises:
        ValueError: If a LoRA config is given but the class does not
            support LoRA.
    """
    lora_capable = supports_lora(model_class)
    if lora_config and not lora_capable:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on github.")

    init_kwargs: Dict[str, Any] = {}

    if lora_capable:
        # Passed even when it is None: lora_config=None disables LoRA.
        init_kwargs["lora_config"] = lora_config

    if supports_multimodal(model_class):
        assert multimodal_config is not None
        init_kwargs["multimodal_config"] = multimodal_config

    # Only forwarded when a scheduler config was actually supplied.
    if has_inner_state(model_class) and scheduler_config:
        init_kwargs["scheduler_config"] = scheduler_config

    return init_kwargs
def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
                cache_config: Optional[CacheConfig],
                quant_config: Optional[QuantizationConfig], *,
                lora_config: Optional[LoRAConfig],
                multimodal_config: Optional[MultiModalConfig],
                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
    """Instantiate ``model_class`` with the core and optional configs.

    ``hf_config`` is passed as ``config``; ``cache_config`` and
    ``quant_config`` are always passed. The LoRA / multimodal / scheduler
    configs are only forwarded when
    ``_get_model_initialization_kwargs`` includes them for this class.
    """
    optional_kwargs = _get_model_initialization_kwargs(
        model_class, lora_config, multimodal_config, scheduler_config)
    return model_class(
        config=hf_config,
        cache_config=cache_config,
        quant_config=quant_config,
        **optional_kwargs,
    )
def _initialize_model(
        model_config: ModelConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        cache_config: CacheConfig,
        scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
    """Construct the (not yet weight-loaded) model described by the configs.

    The model class is looked up with ``get_model_architecture``, the
    quantization config is resolved with ``_get_quantization_config``, and
    the instance is created through ``build_model``.
    """
    # Only the first element of the returned pair (the class) is needed.
    model_class, _unused = get_model_architecture(model_config)
    hf_config = model_config.hf_config
    quant_config = _get_quantization_config(model_config, load_config)
    return build_model(
        model_class,
        hf_config,
        cache_config=cache_config,
        quant_config=quant_config,
        lora_config=lora_config,
        multimodal_config=model_config.multimodal_config,
        scheduler_config=scheduler_config,
    )
class BaseModelLoader(ABC):
    """Abstract base for the model loaders defined in this file."""

    def __init__(self, load_config: LoadConfig):
        # Stored so that concrete loaders can read it when loading.
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model with the given configurations.

        All arguments are keyword-only. Concrete loaders must override
        this method.
        """
class DefaultModelLoader(BaseModelLoader):
    """Loader that reads safetensors / bin / pt checkpoints from disk,
    downloading them first when the model is not a local directory."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader takes no extra options; reject any that were passed.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Fetch the model from the ModelScope hub when that is enabled.

        Returns the local model path when APHRODITE_USE_MODELSCOPE is set
        (either the existing path or the freshly downloaded snapshot), and
        None when ModelScope is not in use.
        """
        if not APHRODITE_USE_MODELSCOPE:
            return None

        # Imported lazily so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        if os.path.exists(model):
            return model

        return snapshot_download(
            model_id=model,
            cache_dir=self.load_config.download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_file_pattern=self.load_config.ignore_patterns,
        )

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Locate (downloading if necessary) the weight files of a model.

        Returns:
            A tuple ``(folder, weight_files, use_safetensors)``.

        Raises:
            ValueError: If the configured load format is not handled here.
            RuntimeError: If no weight file could be found.
        """
        modelscope_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision)
        model_name_or_path = modelscope_path or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False

        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            allow_patterns = ["*.safetensors"]
            use_safetensors = True
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if is_local:
            hf_folder = model_name_or_path
        else:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

        # The first pattern (in priority order) that matches anything wins.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files = glob.glob(os.path.join(hf_folder, pattern))
            if hf_weights_files:
                use_safetensors = (use_safetensors
                                   or pattern == "*.safetensors")
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3 there are both
            # sharded safetensors files and a consolidated safetensors
            # file. Using both breaks. Here, we download the
            # `model.safetensors.index.json` and filter out any files not
            # found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir,
                    revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str],
        fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return a ``(name, tensor)`` iterator matching the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)

        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only supports *.bin checkpoints.
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                model_name_or_path, self.load_config.download_dir, hf_folder,
                hf_weights_files)
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if not current_platform.is_tpu():
            return weights_iterator

        # In PyTorch XLA, we should call `xm.mark_step` frequently so that
        # not too many ops are accumulated in the XLA program.
        import torch_xla.core.xla_model as xm

        def _xla_weights_iterator(iterator: Generator):
            for item in iterator:
                yield item
                xm.mark_step()

        return _xla_weights_iterator(weights_iterator)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model, load its weights and run the quantization
        post-processing hooks; the model is returned in eval mode."""
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config,
                                          scheduler_config)

            # Models may opt out of the *.pt fallback through this attribute.
            fall_back_to_pt = getattr(model, "fall_back_to_pt_during_load",
                                      True)
            weights = self._get_weights_iterator(
                model_config.model,
                model_config.revision,
                fall_back_to_pt=fall_back_to_pt)
            model.load_weights(weights)

            for _, submodule in model.named_modules():
                quant_method = getattr(submodule, "quant_method", None)
                if quant_method is None:
                    continue
                # When quant methods need to process weights after loading
                # (for repacking, quantizing, etc), they expect parameters
                # to be on the global target device. This scope is for the
                # case where cpu offloading is used, where we will move the
                # parameters onto device for processing and back off after.
                with device_loading_context(submodule, target_device):
                    quant_method.process_weights_after_loading(submodule)
        return model.eval()
class DummyModelLoader(BaseModelLoader):
    """Loader that fills the model with random weights instead of reading
    a checkpoint."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader takes no extra options; reject any that were passed.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model on the configured device, give it dummy weights
        and return it in eval mode."""
        device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with device:
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config,
                                          scheduler_config)
            # NOTE: For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        # Accept either a ready-made TensorizerConfig or a mapping of its
        # constructor arguments.
        if not isinstance(extra_config, TensorizerConfig):
            extra_config = TensorizerConfig(**extra_config)
        self.tensorizer_config = extra_config

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Let the tensorizer config check itself against both configs."""
        config = self.tensorizer_config
        config.verify_with_model_config(model_config)
        config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the weight iterator built from the tensorizer arguments."""
        return tensorizer_weights_iterator(
            self.tensorizer_config._construct_tensorizer_args())

    def _load_model_serialized_cpu(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't Aphrodite-tensorized
        (see examples/tensorize_aphrodite_model.py). This should still be
        faster than default HuggingFace loading, but will be slower than
        loading an Aphrodite-tensorized model.
        """
        device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with device:
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)
            weights = self._get_weights_iterator()
            model.load_weights(weights)
        return model.eval()

    def _load_model_serialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects an Aphrodite-tensorized model. See the
        examples/tensorize_aphrodite_model.py example script
        for serializing Aphrodite models.
        """
        with set_default_torch_dtype(model_config.dtype), \
                torch.device(device_config.device):
            model_class = get_model_architecture(model_config)[0]
            quant_config = _get_quantization_config(model_config,
                                                    self.load_config)
            init_kwargs = _get_model_initialization_kwargs(
                model_class, lora_config, model_config.multimodal_config)
            init_kwargs["quant_config"] = quant_config
            init_kwargs["cache_config"] = cache_config

            # Work on a shallow copy so the loader's own config keeps its
            # original attributes.
            run_config = copy.copy(self.tensorizer_config)
            run_config.model_class = model_class
            run_config.hf_config = model_config.hf_config
            run_config.dtype = model_config.dtype

            model = load_with_tensorizer(run_config, **init_kwargs)
        return model.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Verify the configuration and load the model through tensorizer.

        With tensor parallelism enabled, the configured ``tensorizer_uri``
        is used as a %-format template and filled in with this worker's
        tensor-parallel rank.
        """
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from aphrodite.distributed import get_tensor_model_parallel_rank
            uri_template = self.tensorizer_config.tensorizer_uri
            self.tensorizer_config.tensorizer_uri = (
                uri_template % get_tensor_model_parallel_rank())

        if is_aphrodite_tensorized(self.tensorizer_config):
            loader = self._load_model_serialized
        else:
            loader = self._load_model_serialized_cpu
        return loader(model_config, device_config, lora_config, cache_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``model`` using the given tensorizer configuration."""
        serialize_aphrodite_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/save_sharded_state.py` for creating a sharded checkpoint.
    """

    # File name template; `{rank}` is the tensor-parallel rank and `{part}`
    # the index of the file within that rank's shard.
    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        """Read the optional ``pattern`` entry from the extra loader config.

        Raises:
            ValueError: If the extra config contains any key besides
                ``pattern``.
        """
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the keys that are actually unexpected; "pattern"
            # has already been popped from the copy.
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group non-empty tensors by (device, start of underlying storage).
        same_storage_groups: Dict[Any, List[Tuple[
            str, torch.Tensor]]] = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # Address one past the last element of the tensor.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    # Only contiguous tensors are considered as "covers".
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        # t2 does not span the whole range of t.
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor in the group subsumes t.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a local directory holding the checkpoint, downloading the
        ``*.safetensors`` files when the argument is not a directory."""
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        return download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            ["*.safetensors"],
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model and fill it from this rank's shard files.

        Raises:
            ValueError: If no shard file matches the pattern for this rank,
                or if some model tensors were not present in the files.
        """
        from safetensors.torch import safe_open

        from aphrodite.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            # Entries are popped as they are loaded; whatever is left at
            # the end was missing from the checkpoint.
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning("loading tensor of shape "
                                           f"{tensor.shape} into parameter "
                                           f"'{key}' of shape {param_shape}")
                        param_data.copy_(tensor)
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write this rank's state dict to ``path`` as safetensors files.

        Args:
            model: Module whose (de-duplicated) state dict is saved.
            path: Directory the files are written to.
            pattern: File name template with ``{rank}`` and ``{part}``
                fields; defaults to ``DEFAULT_PATTERN``.
            max_size: Optional size budget in bytes per file. A new part is
                started before a tensor would push the current part over
                the budget. A single tensor larger than the budget still
                gets written, in a part of its own.
        """
        from safetensors.torch import save_file

        from aphrodite.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Flush only a non-empty part: otherwise an oversized first
            # tensor would produce a file that contains no tensors.
            if (max_size is not None and state_dict_part
                    and total_size + param_size > max_size):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if state_dict_part:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
- class BitsAndBytesModelLoader(BaseModelLoader):
- """Model loader to load model weights with BitAndBytes quantization."""
- default_target_modules = [
- "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
- "o_proj"
- ]
- possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig):
    """Decide which modules get quantized.

    When the extra loader config names a QLoRA adapter under
    ``qlora_adapter_name_or_path``, the list is read from the
    ``target_modules`` entry of that adapter's config file; otherwise
    ``default_target_modules`` is used.
    """
    super().__init__(load_config)

    # we don't need to quantize the whole model, only the target modules
    # that are specified in the adapter config file. If the adapter config
    # file is not provided, we will quantize the default modules.
    extra_config = load_config.model_loader_extra_config
    has_adapter = bool(extra_config) and (
        "qlora_adapter_name_or_path" in extra_config)
    if not has_adapter:
        self.target_modules = self.default_target_modules
        return

    adapter = extra_config["qlora_adapter_name_or_path"]
    adapter_config_path = self._get_config_file(adapter)
    with open(adapter_config_path, "r") as f:
        adapter_config = json.load(f)
    self.target_modules = adapter_config["target_modules"]
- def _get_config_file(self, qlora_adapter: str) -> str:
- is_local = os.path.isdir(qlora_adapter)
- config_file_path = None
- if is_local:
- for file in self.possible_config_file_names:
- config_file_path = os.path.join(qlora_adapter, file)
- if os.path.exists(config_file_path):
- break
- else:
- hf_api = HfApi()
- repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
- for file in self.possible_config_file_names:
- if file in repo_files:
- config_file_path = hf_hub_download(repo_id=qlora_adapter,
- filename=file)
- break
- if not config_file_path:
- raise ValueError(
- f"Cannot find adapter config file in {qlora_adapter}")
- return config_file_path
- def _get_weight_files(
- self,
- model_name_or_path: str,
- allowed_patterns: List[str],
- revision: Optional[str] = None) -> Tuple[List[str], str]:
- """Retrieve weight files. Download the files if necessary.
-
- Return the weight files and the file pattern."""
- is_local = os.path.isdir(model_name_or_path)
- if is_local:
- for pattern in allowed_patterns:
- weight_files = glob.glob(
- os.path.join(model_name_or_path, pattern))
- if weight_files:
- return weight_files, pattern
- else:
- hf_api = HfApi()
- repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
- for pattern in allowed_patterns:
- matching_files = fnmatch.filter(repo_files, pattern)
- if matching_files:
- hf_folder = download_weights_from_hf(
- model_name_or_path,
- self.load_config.download_dir,
- [pattern],
- revision,
- ignore_patterns=self.load_config.ignore_patterns,
- )
- return glob.glob(os.path.join(hf_folder, pattern)), pattern
- raise RuntimeError(
- f"No model weights found in: `{model_name_or_path}`")
- def _prepare_weights(self, model_name_or_path: str,
- revision: Optional[str]) -> Tuple[List[str], bool]:
- """Prepare weight files for the model."""
- allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
- hf_weights_files, matched_pattern = self._get_weight_files(
- model_name_or_path, allowed_patterns, revision)
- if matched_pattern != "*.safetensors":
- hf_weights_files = filter_files_not_needed_for_inference(
- hf_weights_files)
- if len(hf_weights_files) == 0:
- raise RuntimeError(
- f"Cannot find any model weights with `{model_name_or_path}`")
- return hf_weights_files, matched_pattern == "*.safetensors"
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
    """Return the weight iterator for ``hf_weights_files``, choosing the
    safetensors reader when ``use_safetensors`` is true and the pt reader
    otherwise."""
    make_iterator = (safetensors_weights_iterator
                     if use_safetensors else pt_weights_iterator)
    return make_iterator(hf_weights_files)
def _get_quantized_weights_iterator(
    self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                 Any]]:
    """Get an iterator to the model weights with bitsandbytes quantization,
    as well as the quantization state dictionary.

    Args:
        model_name_or_path: Local directory or repo id of the checkpoint.
        revision: Optional checkpoint revision.
        pre_quant: When true the checkpoint is read as already quantized
            and its stored quant states are parsed; when false the weights
            whose name contains one of ``self.target_modules`` are
            quantized to nf4 while iterating.

    Returns:
        A ``(name, tensor)`` generator and the quant-state dictionary. The
        dictionary is filled lazily, while the generator is consumed.

    Raises:
        ImportError: If bitsandbytes is missing or older than 0.42.0.
    """
    import re

    def _release(version: str) -> Tuple[int, ...]:
        # Leading numeric components padded to three, so "0.42" compares
        # equal to "0.42.0" and suffixes such as ".dev0" are ignored.
        match = re.match(r"\d+(?:\.\d+)*", version)
        numbers = (tuple(int(part) for part in match.group().split("."))
                   if match else ())
        return (numbers + (0, 0, 0))[:3]

    # only load the bitsandbytes module when needed
    try:
        import bitsandbytes
        from bitsandbytes.functional import QuantState

        # Compare numerically: a plain string comparison would rank
        # "0.9.0" above "0.42.0" and "0.100.0" below it.
        if _release(bitsandbytes.__version__) < (0, 42, 0):
            raise ImportError("bitsandbytes version is wrong. Please "
                              "install bitsandbytes>=0.42.0.")
        from bitsandbytes.functional import quantize_4bit
    except ImportError as err:
        raise ImportError("Please install bitsandbytes>=0.42.0 via "
                          "`pip install bitsandbytes>=0.42.0` to use "
                          "bitsandbytes quantizer.") from err

    hf_weights_files, use_safetensors = self._prepare_weights(
        model_name_or_path, revision)

    quant_state_dict = {}

    def quantized_checkpoint() -> Generator:
        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str,
                               temp_state_dict: Dict) -> QuantState:
            quant_state = {}
            for k in temp_state_dict:
                if param_name + "." in k:
                    quant_state[k] = temp_state_dict[k]
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__nf4 in CPU
            nf4_key = param_name + ".quant_state.bitsandbytes__nf4"
            quant_state[nf4_key] = quant_state[nf4_key].cpu().data
            return QuantState.from_dict(quant_state, device="cuda")

        # First iterate over all quant state weights
        temp_state_dict = {}
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if weight_name.endswith(".weight"):
                continue
            # TODO: only nf4 quantization is supported for now
            if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
                raise NotImplementedError(
                    "Only bitsandbytes_nf4 quantization "
                    f"is supported for now. {weight_name} is fp4 quantized"
                )
            temp_state_dict[weight_name] = weight_tensor

        # Second iterate over all prequant and normal weights
        # pre quantized weights would have a quant_state
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            # Filter out all weights whose suffix is not ".weight"
            if not weight_name.endswith(".weight"):
                continue
            if weight_name + ".quant_state.bitsandbytes__nf4" \
                    in temp_state_dict:
                quant_state = _parse_quant_state(weight_name,
                                                 temp_state_dict)
                weight_name = weight_name.replace(".weight", ".qweight")
                quant_state_dict[weight_name] = quant_state
                yield weight_name, weight_tensor
            else:
                yield weight_name, weight_tensor

    def generator() -> Generator:
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if any(target_module in weight_name
                   for target_module in self.target_modules):
                weight_name = weight_name.replace(".weight", ".qweight")
                # bitsandbytes requires data in GPU
                loaded_weight = weight_tensor.cuda().data
                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
                        compress_statistics=True,
                        quant_type="nf4")
                quant_state_dict[weight_name] = quant_state
            else:
                processed_weight = weight_tensor
            yield weight_name, processed_weight

    if pre_quant:
        return quantized_checkpoint(), quant_state_dict
    return generator(), quant_state_dict
- def _load_weights(self, model_config: ModelConfig,
- model: nn.Module) -> None:
- if not hasattr(model, 'load_weights'):
- raise AttributeError(
- "The required method 'load_weights' is not defined in class"
- f" {type(self).__name__}.")
- if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
- raise AttributeError(
- f"Model {type(self).__name__} does not support BitsAndBytes "
- "quantization yet.")
- logger.info("Loading weights with BitsAndBytes quantization. "
- "This May take a while ...")
- is_quantized_checkpoint = False
- quant_config = getattr(model_config.hf_config, "quantization_config",
- None)
- if quant_config is not None and quant_config.get(
- 'quant_method') == "bitsandbytes":
- is_quantized_checkpoint = True
- qweight_iterator, quant_state_dict = \
- self._get_quantized_weights_iterator(
- model_config.model, model_config.revision, is_quantized_checkpoint)
- model.load_weights(qweight_iterator)
- torch.cuda.empty_cache()
- param_dict = dict(model.named_parameters())
- stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
- for quant_param_name in quant_state_dict:
- non_stacked_param_name = quant_param_name
- shard_index = 0
- for shard_name, (
- weight_name, index
- ) in model.bitsandbytes_stacked_params_mapping.items():
- if shard_name in quant_param_name:
- shard_index = index
- quant_param_name = quant_param_name.replace(
- shard_name, weight_name)
- break
- if quant_param_name not in param_dict:
- raise ValueError(
- f"Parameter {quant_param_name} not found in the model.")
- if quant_param_name not in stacked_quant_state_dict:
- stacked_quant_state_dict[quant_param_name] = {}
- stacked_quant_state_dict[quant_param_name][shard_index] = (
- quant_state_dict[non_stacked_param_name])
- # save quant_states and offsets as the attributes of the parameters
- for param_name, param in param_dict.items():
- if param_name in stacked_quant_state_dict:
- quant_states = stacked_quant_state_dict[param_name]
- set_weight_attrs(param, {"bnb_quant_state": quant_states})
- pack_ratio = getattr(param, "pack_factor", -1)
- if pack_ratio == -1:
- raise ValueError(
- f"pack_factor not set for parameter {param_name}.")
- num_elements = [0] * len(quant_states)
- for seq, quant_state in quant_states.items():
- num_elements[seq] = math.prod(
- quant_state.shape) // pack_ratio
- offsets = np.concatenate(([0], np.cumsum(num_elements)))
- set_weight_attrs(param, {"bnb_shard_offsets": offsets})
def load_model(self, *, model_config: ModelConfig,
               device_config: DeviceConfig,
               lora_config: Optional[LoRAConfig],
               parallel_config: ParallelConfig,
               scheduler_config: SchedulerConfig,
               cache_config: CacheConfig) -> nn.Module:
    """Build the model, load its BitsAndBytes weights, and return it.

    ``parallel_config`` and ``scheduler_config`` are accepted but not used
    by this method.

    Returns:
        The result of calling ``eval()`` on the initialized model.
    """
    # NOTE(review): the original indentation is not visible in this view;
    # weight loading is assumed to run while both the default-dtype and the
    # device contexts are active -- confirm against the upstream file.
    with set_default_torch_dtype(model_config.dtype), \
            torch.device(device_config.device):
        bnb_model = _initialize_model(model_config, self.load_config,
                                      lora_config, cache_config)
        self._load_weights(model_config, bnb_model)

    return bnb_model.eval()
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader for checkpoints stored in the GGUF format, i.e. models
    that were quantized with GGUF. The model path handed to this loader
    must point at an existing file.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader accepts no extra loader configuration.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return the path unchanged when it is a file, otherwise raise."""
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build a dict that maps GGUF tensor names to HF parameter names.

        GGUF names tensors converted from an HF checkpoint as
        `blk.N.BB.weight` and `blk.N.BB.bias`, where N is the block number
        of a layer and BB identifies the attention/mlp component.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
        """
        hf_config = model_config.hf_config
        model_type = hf_config.model_type
        # hack: gguf uses a different name than transformers for this one
        if model_type == "cohere":
            model_type = "command-r"

        # First architecture key whose registered name equals the model type.
        arch = next(
            (arch_key
             for arch_key, arch_name in gguf.MODEL_ARCH_NAMES.items()
             if arch_name == model_type),
            None,
        )
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        tensor_names = gguf.get_tensor_name_map(arch,
                                                hf_config.num_hidden_layers)

        # Instantiate on the "meta" device: only the parameter names are
        # needed below.
        with torch.device("meta"):
            skeleton = AutoModelForCausalLM.from_config(hf_config)

        gguf_to_hf_name_map = {}
        for hf_name in skeleton.state_dict():
            stem, suffix = hf_name.rsplit(".", 1)
            gguf_key = f"{tensor_names.get_name(stem)}.{suffix}"
            gguf_to_hf_name_map[gguf_key] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the GGUF weight iterator for the given file and name map."""
        weights = gguf_quant_weights_iterator(model_name_or_path,
                                              gguf_to_hf_name_map)
        return weights

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Initialize the model and load its weights from a GGUF file."""
        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights
        extra_names = get_gguf_extra_tensor_names(gguf_path, name_map)
        if "lm_head.weight" in extra_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        # NOTE(review): the original indentation is not visible in this
        # view; weight loading is assumed to sit inside the default-dtype
        # context but outside the device context -- confirm upstream.
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)
            model.load_weights(
                self._get_weights_iterator(gguf_path, name_map))
        return model
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Return the model loader that matches ``load_config.load_format``.

    A load format that is itself a class is instantiated directly with
    ``load_config``. Otherwise the format is compared against the known
    formats in order, and the default loader is used when none matches.
    """
    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    # Checked in order, with the same equality test as an if-chain.
    known_loaders = (
        (LoadFormat.DUMMY, DummyModelLoader),
        (LoadFormat.TENSORIZER, TensorizerLoader),
        (LoadFormat.SHARDED_STATE, ShardedStateLoader),
        (LoadFormat.BITSANDBYTES, BitsAndBytesModelLoader),
        (LoadFormat.GGUF, GGUFModelLoader),
    )
    for known_format, loader_cls in known_loaders:
        if load_config.load_format == known_format:
            return loader_cls(load_config)

    return DefaultModelLoader(load_config)
|