loader.py 49 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141
  1. # ruff: noqa: SIM117
  2. import collections
  3. import copy
  4. import fnmatch
  5. import glob
  6. import json
  7. import math
  8. import os
  9. from abc import ABC, abstractmethod
  10. from contextlib import contextmanager
  11. from typing import Any, Dict, Generator, List, Optional, Tuple, Type
  12. import gguf
  13. import huggingface_hub
  14. import numpy as np
  15. import torch
  16. from huggingface_hub import HfApi, hf_hub_download
  17. from loguru import logger
  18. from torch import nn
  19. from transformers import AutoModelForCausalLM, PretrainedConfig
  20. from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
  21. from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
  22. DeviceConfig, LoadConfig, LoadFormat,
  23. LoRAConfig, ModelConfig, MultiModalConfig,
  24. ParallelConfig, SchedulerConfig)
  25. from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
  26. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  27. get_tensor_model_parallel_world_size)
  28. from aphrodite.modeling.model_loader.tensorizer import (
  29. TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
  30. serialize_aphrodite_model, tensorizer_weights_iterator)
  31. from aphrodite.modeling.model_loader.utils import (get_model_architecture,
  32. set_default_torch_dtype)
  33. from aphrodite.modeling.model_loader.weight_utils import (
  34. download_safetensors_index_file_from_hf, download_weights_from_hf,
  35. filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
  36. get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
  37. initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
  38. safetensors_weights_iterator)
  39. from aphrodite.modeling.models.interfaces import (has_inner_state,
  40. supports_lora,
  41. supports_multimodal)
  42. from aphrodite.modeling.utils import set_weight_attrs
  43. from aphrodite.platforms import current_platform
  44. from aphrodite.quantization.base_config import QuantizationConfig
  45. @contextmanager
  46. def device_loading_context(module: torch.nn.Module,
  47. target_device: torch.device):
  48. if target_device.type == "cpu":
  49. # If target is CPU, no need to move anything
  50. yield module
  51. return
  52. original_device_states: Dict[str, torch.device] = {}
  53. # Store original device states and move parameters to GPU if they're on CPU
  54. for name, p in module.named_parameters():
  55. if p.device.type == "cpu":
  56. original_device_states[name] = p.device
  57. p.data = p.data.to(target_device)
  58. # Parameters already on target device are not touched
  59. try:
  60. yield module
  61. finally:
  62. # Restore parameters to their original devices, ignoring new parameters
  63. pin_memory = is_pin_memory_available()
  64. for name, p in module.named_parameters():
  65. if name in original_device_states:
  66. original_device: torch.device = original_device_states[name]
  67. if original_device.type == "cpu":
  68. # `torch.empty_like` does not support `pin_memory` argument
  69. cpu_data = torch.empty_strided(size=p.data.size(),
  70. stride=p.data.stride(),
  71. dtype=p.data.dtype,
  72. layout=p.data.layout,
  73. device="cpu",
  74. pin_memory=pin_memory)
  75. cpu_data.copy_(p.data)
  76. p.data = cpu_data
  77. else:
  78. p.data = p.data.to(original_device)
  79. # New parameters or parameters already on target device are untouched
  80. def _get_quantization_config(
  81. model_config: ModelConfig,
  82. load_config: LoadConfig) -> Optional[QuantizationConfig]:
  83. """Get the quantization config."""
  84. if model_config.quantization is not None:
  85. quant_config = get_quant_config(model_config, load_config)
  86. capability = current_platform.get_device_capability() # type: ignore
  87. if capability is not None:
  88. capability = capability[0] * 10 + capability[1]
  89. if capability < quant_config.get_min_capability():
  90. raise ValueError(
  91. f"The quantization method {model_config.quantization} "
  92. "is not supported for the current GPU. "
  93. f"Minimum capability: {quant_config.get_min_capability()}. "
  94. f"Current capability: {capability}.")
  95. supported_dtypes = quant_config.get_supported_act_dtypes()
  96. if model_config.dtype not in supported_dtypes:
  97. raise ValueError(
  98. f"{model_config.dtype} is not supported for quantization "
  99. f"method {model_config.quantization}. Supported dtypes: "
  100. f"{supported_dtypes}")
  101. return quant_config
  102. return None
  103. def _get_model_initialization_kwargs(
  104. model_class: Type[nn.Module],
  105. lora_config: Optional[LoRAConfig],
  106. multimodal_config: Optional[MultiModalConfig],
  107. scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
  108. """Get extra kwargs for model initialization."""
  109. extra_kwargs: Dict[str, Any] = {}
  110. if supports_lora(model_class):
  111. # lora_config=None is used to disable LoRA
  112. extra_kwargs["lora_config"] = lora_config
  113. elif lora_config:
  114. raise ValueError(
  115. f"Model {model_class.__name__} does not support LoRA, "
  116. "but LoRA is enabled. Support for this model may "
  117. "be added in the future. If this is important to you, "
  118. "please open an issue on github.")
  119. if supports_multimodal(model_class):
  120. assert multimodal_config is not None
  121. extra_kwargs["multimodal_config"] = multimodal_config
  122. if has_inner_state(model_class) and scheduler_config:
  123. extra_kwargs["scheduler_config"] = scheduler_config
  124. return extra_kwargs
  125. def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
  126. cache_config: Optional[CacheConfig],
  127. quant_config: Optional[QuantizationConfig], *,
  128. lora_config: Optional[LoRAConfig],
  129. multimodal_config: Optional[MultiModalConfig],
  130. scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
  131. extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
  132. multimodal_config,
  133. scheduler_config)
  134. return model_class(config=hf_config,
  135. cache_config=cache_config,
  136. quant_config=quant_config,
  137. **extra_kwargs)
  138. def _initialize_model(
  139. model_config: ModelConfig,
  140. load_config: LoadConfig,
  141. lora_config: Optional[LoRAConfig],
  142. cache_config: CacheConfig,
  143. scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
  144. """Initialize a model with the given configurations."""
  145. model_class, _ = get_model_architecture(model_config)
  146. return build_model(
  147. model_class,
  148. model_config.hf_config,
  149. cache_config=cache_config,
  150. quant_config=_get_quantization_config(model_config, load_config),
  151. lora_config=lora_config,
  152. multimodal_config=model_config.multimodal_config,
  153. scheduler_config=scheduler_config,
  154. )
  155. class BaseModelLoader(ABC):
  156. """Base class for model loaders."""
  157. def __init__(self, load_config: LoadConfig):
  158. self.load_config = load_config
  159. @abstractmethod
  160. def load_model(self, *, model_config: ModelConfig,
  161. device_config: DeviceConfig,
  162. lora_config: Optional[LoRAConfig],
  163. parallel_config: ParallelConfig,
  164. scheduler_config: SchedulerConfig,
  165. cache_config: CacheConfig) -> nn.Module:
  166. """Load a model with the given configurations."""
  167. ...
  168. class DefaultModelLoader(BaseModelLoader):
  169. """Model loader that can load different file types from disk."""
  170. def __init__(self, load_config: LoadConfig):
  171. super().__init__(load_config)
  172. if load_config.model_loader_extra_config:
  173. raise ValueError(f"Model loader extra config is not supported for "
  174. f"load format {load_config.load_format}")
  175. def _maybe_download_from_modelscope(
  176. self, model: str, revision: Optional[str]) -> Optional[str]:
  177. """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
  178. True.
  179. Returns the path to the downloaded model, or None if the model is not
  180. downloaded from ModelScope."""
  181. if APHRODITE_USE_MODELSCOPE:
  182. # download model from ModelScope hub,
  183. # lazy import so that modelscope is not required for normal use.
  184. # pylint: disable=C.
  185. from modelscope.hub.snapshot_download import snapshot_download
  186. if not os.path.exists(model):
  187. model_path = snapshot_download(
  188. model_id=model,
  189. cache_dir=self.load_config.download_dir,
  190. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  191. revision=revision,
  192. ignore_file_pattern=self.load_config.ignore_patterns,
  193. )
  194. else:
  195. model_path = model
  196. return model_path
  197. return None
  198. def _prepare_weights(self, model_name_or_path: str,
  199. revision: Optional[str],
  200. fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
  201. """Prepare weights for the model.
  202. If the model is not local, it will be downloaded."""
  203. model_name_or_path = self._maybe_download_from_modelscope(
  204. model_name_or_path, revision) or model_name_or_path
  205. is_local = os.path.isdir(model_name_or_path)
  206. load_format = self.load_config.load_format
  207. use_safetensors = False
  208. index_file = SAFE_WEIGHTS_INDEX_NAME
  209. # Some quantized models use .pt files for storing the weights.
  210. if load_format == LoadFormat.AUTO:
  211. allow_patterns = ["*.safetensors", "*.bin"]
  212. elif load_format == LoadFormat.SAFETENSORS:
  213. use_safetensors = True
  214. allow_patterns = ["*.safetensors"]
  215. elif load_format == LoadFormat.MISTRAL:
  216. use_safetensors = True
  217. allow_patterns = ["consolidated*.safetensors"]
  218. index_file = "consolidated.safetensors.index.json"
  219. elif load_format == LoadFormat.PT:
  220. allow_patterns = ["*.pt"]
  221. elif load_format == LoadFormat.NPCACHE:
  222. allow_patterns = ["*.bin"]
  223. else:
  224. raise ValueError(f"Unknown load_format: {load_format}")
  225. if fall_back_to_pt:
  226. allow_patterns += ["*.pt"]
  227. if not is_local:
  228. hf_folder = download_weights_from_hf(
  229. model_name_or_path,
  230. self.load_config.download_dir,
  231. allow_patterns,
  232. revision,
  233. ignore_patterns=self.load_config.ignore_patterns,
  234. )
  235. else:
  236. hf_folder = model_name_or_path
  237. hf_weights_files: List[str] = []
  238. for pattern in allow_patterns:
  239. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  240. if len(hf_weights_files) > 0:
  241. if pattern == "*.safetensors":
  242. use_safetensors = True
  243. break
  244. if use_safetensors:
  245. # For models like Mistral-7B-Instruct-v0.3
  246. # there are both sharded safetensors files and a consolidated
  247. # safetensors file. Using both breaks.
  248. # Here, we download the `model.safetensors.index.json` and filter
  249. # any files not found in the index.
  250. if not is_local:
  251. download_safetensors_index_file_from_hf(
  252. model_name_or_path, index_file,
  253. self.load_config.download_dir, revision)
  254. hf_weights_files = filter_duplicate_safetensors_files(
  255. hf_weights_files, hf_folder, index_file)
  256. else:
  257. hf_weights_files = filter_files_not_needed_for_inference(
  258. hf_weights_files)
  259. if len(hf_weights_files) == 0:
  260. raise RuntimeError(
  261. f"Cannot find any model weights with `{model_name_or_path}`")
  262. return hf_folder, hf_weights_files, use_safetensors
  263. def _get_weights_iterator(
  264. self, model_name_or_path: str, revision: Optional[str],
  265. fall_back_to_pt: bool
  266. ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], int]:
  267. """Get an iterator for the model weights based on the load format."""
  268. hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
  269. model_name_or_path, revision, fall_back_to_pt)
  270. est_weight_bytes = sum(os.path.getsize(f)
  271. for f in hf_weights_files)
  272. if self.load_config.load_format == LoadFormat.NPCACHE:
  273. # Currently np_cache only support *.bin checkpoints
  274. assert use_safetensors is False
  275. weights_iterator = np_cache_weights_iterator(
  276. model_name_or_path, self.load_config.download_dir, hf_folder,
  277. hf_weights_files)
  278. elif use_safetensors:
  279. weights_iterator = safetensors_weights_iterator(hf_weights_files)
  280. else:
  281. weights_iterator = pt_weights_iterator(hf_weights_files)
  282. if current_platform.is_tpu():
  283. # In PyTorch XLA, we should call `xm.mark_step` frequently so that
  284. # not too many ops are accumulated in the XLA program.
  285. import torch_xla.core.xla_model as xm
  286. def _xla_weights_iterator(iterator: Generator):
  287. for weights in iterator:
  288. yield weights
  289. xm.mark_step()
  290. weights_iterator = _xla_weights_iterator(weights_iterator)
  291. return weights_iterator, est_weight_bytes
  292. def load_model(self, *, model_config: ModelConfig,
  293. device_config: DeviceConfig,
  294. lora_config: Optional[LoRAConfig],
  295. parallel_config: ParallelConfig,
  296. scheduler_config: SchedulerConfig,
  297. cache_config: CacheConfig) -> nn.Module:
  298. target_device = torch.device(device_config.device)
  299. with set_default_torch_dtype(model_config.dtype):
  300. with target_device:
  301. model = _initialize_model(model_config, self.load_config,
  302. lora_config, cache_config,
  303. scheduler_config)
  304. weights, wgt_bytes = self._get_weights_iterator(model_config.model,
  305. model_config.revision,
  306. fall_back_to_pt=getattr(
  307. model,
  308. "fall_back_to_pt_during_load",
  309. True))
  310. model.load_weights(tensor_progress_bar(weights, wgt_bytes,
  311. "Loading model weights..."))
  312. for _, module in model.named_modules():
  313. quant_method = getattr(module, "quant_method", None)
  314. if quant_method is not None:
  315. # When quant methods need to process weights after loading
  316. # (for repacking, quantizing, etc), they expect parameters
  317. # to be on the global target device. This scope is for the
  318. # case where cpu offloading is used, where we will move the
  319. # parameters onto device for processing and back off after.
  320. with device_loading_context(module, target_device):
  321. quant_method.process_weights_after_loading(module)
  322. return model.eval()
  323. class DummyModelLoader(BaseModelLoader):
  324. """Model loader that will set model weights to random values."""
  325. def __init__(self, load_config: LoadConfig):
  326. super().__init__(load_config)
  327. if load_config.model_loader_extra_config:
  328. raise ValueError(f"Model loader extra config is not supported for "
  329. f"load format {load_config.load_format}")
  330. def load_model(self, *, model_config: ModelConfig,
  331. device_config: DeviceConfig,
  332. lora_config: Optional[LoRAConfig],
  333. parallel_config: ParallelConfig,
  334. scheduler_config: SchedulerConfig,
  335. cache_config: CacheConfig) -> nn.Module:
  336. with set_default_torch_dtype(model_config.dtype):
  337. with torch.device(device_config.device):
  338. model = _initialize_model(model_config, self.load_config,
  339. lora_config, cache_config,
  340. scheduler_config)
  341. # NOTE: For accurate performance evaluation, we assign
  342. # random values to the weights.
  343. initialize_dummy_weights(model)
  344. return model.eval()
  345. class TensorizerLoader(BaseModelLoader):
  346. """Model loader using CoreWeave's tensorizer library."""
  347. def __init__(self, load_config: LoadConfig):
  348. super().__init__(load_config)
  349. if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
  350. self.tensorizer_config = load_config.model_loader_extra_config
  351. else:
  352. self.tensorizer_config = TensorizerConfig(
  353. **load_config.model_loader_extra_config)
  354. def _verify_config(self, model_config: ModelConfig,
  355. parallel_config: ParallelConfig):
  356. self.tensorizer_config.verify_with_model_config(model_config)
  357. self.tensorizer_config.verify_with_parallel_config(parallel_config)
  358. def _get_weights_iterator(
  359. self) -> Generator[Tuple[str, torch.Tensor], None, None]:
  360. tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
  361. return tensorizer_weights_iterator(tensorizer_args)
  362. def _load_model_serialized_cpu(
  363. self,
  364. model_config: ModelConfig,
  365. device_config: DeviceConfig,
  366. lora_config: Optional[LoRAConfig],
  367. cache_config: CacheConfig,
  368. ) -> nn.Module:
  369. """Load a serialized model with tensorizer to the CPU.
  370. This is only necessary when the model isn't Aphrodite-tensorized (see
  371. examples/tensorize_aphrodite_model.py) This should still be faster than
  372. default HuggingFace loading, but will be slower than loading a
  373. Aphrodite-tensorized model.
  374. """
  375. with set_default_torch_dtype(model_config.dtype):
  376. with torch.device(device_config.device):
  377. model = _initialize_model(model_config, self.load_config,
  378. lora_config, cache_config)
  379. model.load_weights(self._get_weights_iterator())
  380. return model.eval()
  381. def _load_model_serialized(
  382. self,
  383. model_config: ModelConfig,
  384. device_config: DeviceConfig,
  385. lora_config: Optional[LoRAConfig],
  386. cache_config: CacheConfig,
  387. ) -> nn.Module:
  388. """Load a serialized model with tensorizer.
  389. Expects a Aphrodite-tensorized model. See the
  390. examples/tensorize_aphrodite_model.py example script
  391. for serializing Aphrodite models."""
  392. with set_default_torch_dtype(model_config.dtype):
  393. with torch.device(device_config.device):
  394. model_class = get_model_architecture(model_config)[0]
  395. quant_config = _get_quantization_config(
  396. model_config, self.load_config)
  397. extra_kwargs = _get_model_initialization_kwargs(
  398. model_class, lora_config, model_config.multimodal_config)
  399. extra_kwargs["quant_config"] = quant_config
  400. extra_kwargs["cache_config"] = cache_config
  401. tensorizer_config = copy.copy(self.tensorizer_config)
  402. tensorizer_config.model_class = model_class
  403. tensorizer_config.hf_config = model_config.hf_config
  404. tensorizer_config.dtype = model_config.dtype
  405. model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
  406. return model.eval()
  407. def load_model(self, *, model_config: ModelConfig,
  408. device_config: DeviceConfig,
  409. lora_config: Optional[LoRAConfig],
  410. parallel_config: ParallelConfig,
  411. scheduler_config: SchedulerConfig,
  412. cache_config: CacheConfig) -> nn.Module:
  413. self._verify_config(model_config, parallel_config)
  414. if parallel_config.tensor_parallel_size > 1:
  415. from aphrodite.distributed import get_tensor_model_parallel_rank
  416. self.tensorizer_config.tensorizer_uri = \
  417. self.tensorizer_config.tensorizer_uri \
  418. % get_tensor_model_parallel_rank()
  419. if is_aphrodite_tensorized(self.tensorizer_config):
  420. return self._load_model_serialized(model_config, device_config,
  421. lora_config, cache_config)
  422. return self._load_model_serialized_cpu(model_config, device_config,
  423. lora_config, cache_config)
  424. @staticmethod
  425. def save_model(
  426. model: torch.nn.Module,
  427. tensorizer_config: TensorizerConfig,
  428. ) -> None:
  429. serialize_aphrodite_model(
  430. model=model,
  431. tensorizer_config=tensorizer_config,
  432. )
  433. class ShardedStateLoader(BaseModelLoader):
  434. """
  435. Model loader that directly loads each worker's model state dict, which
  436. enables a fast load path for large tensor-parallel models where each worker
  437. only needs to read its own shard rather than the entire checkpoint. See
  438. `examples/save_sharded_state.py` for creating a sharded checkpoint.
  439. """
  440. DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
  441. def __init__(self, load_config: LoadConfig):
  442. super().__init__(load_config)
  443. extra_config = ({} if load_config.model_loader_extra_config is None
  444. else load_config.model_loader_extra_config.copy())
  445. self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
  446. if extra_config:
  447. raise ValueError(f"Unexpected extra config keys for load format "
  448. f"{load_config.load_format}: "
  449. f"{load_config.model_loader_extra_config.keys()}")
  450. @staticmethod
  451. def _filter_subtensors(
  452. tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
  453. """
  454. Filter out all tensors that share the same memory or a subset of the
  455. memory of another tensor.
  456. """
  457. same_storage_groups: Dict[Any, List[Tuple[
  458. str, torch.Tensor]]] = collections.defaultdict(list)
  459. for key, tensor in tensors.items():
  460. if tensor.numel():
  461. ptr = tensor.untyped_storage().data_ptr()
  462. same_storage_groups[tensor.device, ptr].append((key, tensor))
  463. def get_end_ptr(tensor: torch.Tensor) -> int:
  464. return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
  465. result: Dict[str, torch.Tensor] = {}
  466. for group in same_storage_groups.values():
  467. for k, t in group:
  468. a, b = t.data_ptr(), get_end_ptr(t)
  469. for k2, t2 in group:
  470. if not t2.is_contiguous():
  471. continue
  472. a2, b2 = t2.data_ptr(), get_end_ptr(t2)
  473. if a < a2 or b2 < b:
  474. continue
  475. if a2 < a or b < b2 or not t.is_contiguous():
  476. break # t2 covers strictly more memory than t.
  477. if k2 < k:
  478. # Same tensors, keep the one with the smaller key.
  479. break
  480. else:
  481. result[k] = t
  482. return result
  483. def _prepare_weights(self, model_name_or_path: str,
  484. revision: Optional[str]):
  485. if os.path.isdir(model_name_or_path):
  486. return model_name_or_path
  487. else:
  488. allow_patterns = ["*.safetensors"]
  489. return download_weights_from_hf(
  490. model_name_or_path,
  491. self.load_config.download_dir,
  492. allow_patterns,
  493. revision,
  494. ignore_patterns=self.load_config.ignore_patterns,
  495. )
  496. def load_model(self, *, model_config: ModelConfig,
  497. device_config: DeviceConfig,
  498. lora_config: Optional[LoRAConfig],
  499. parallel_config: ParallelConfig,
  500. scheduler_config: SchedulerConfig,
  501. cache_config: CacheConfig) -> nn.Module:
  502. from safetensors.torch import safe_open
  503. from aphrodite.distributed import get_tensor_model_parallel_rank
  504. local_model_path = self._prepare_weights(model_config.model,
  505. model_config.revision)
  506. with set_default_torch_dtype(model_config.dtype):
  507. with torch.device(device_config.device):
  508. model = _initialize_model(model_config, self.load_config,
  509. lora_config, cache_config)
  510. for _, module in model.named_modules():
  511. quant_method = getattr(module, "quant_method", None)
  512. if quant_method is not None:
  513. quant_method.process_weights_after_loading(module)
  514. rank = get_tensor_model_parallel_rank()
  515. pattern = os.path.join(
  516. local_model_path,
  517. self.pattern.format(rank=rank, part="*"),
  518. )
  519. filepaths = glob.glob(pattern)
  520. if not filepaths:
  521. # TODO: support un-sharded checkpoints too
  522. raise ValueError(
  523. f"Could not find checkpoint files '{pattern}', only "
  524. f"pre-sharded checkpoints are currently supported!")
  525. state_dict = self._filter_subtensors(model.state_dict())
  526. for path in filepaths:
  527. with safe_open(path, framework="pt") as f:
  528. for key in f.keys(): # noqa: SIM118
  529. tensor = f.get_tensor(key)
  530. # If loading with LoRA enabled, additional padding may
  531. # be added to certain parameters. We only load into a
  532. # narrowed view of the parameter data.
  533. param_data = state_dict[key].data
  534. param_shape = state_dict[key].shape
  535. for dim, size in enumerate(tensor.shape):
  536. if size < param_shape[dim]:
  537. param_data = param_data.narrow(dim, 0, size)
  538. if tensor.shape != param_shape:
  539. logger.warning("loading tensor of shape "
  540. f"{tensor.shape} into parameter "
  541. f"'{key}' of shape {param_shape}")
  542. param_data.copy_(tensor)
  543. state_dict.pop(key)
  544. if state_dict:
  545. raise ValueError(
  546. f"Missing keys {tuple(state_dict)} in loaded state!")
  547. return model.eval()
  548. @staticmethod
  549. def save_model(
  550. model: torch.nn.Module,
  551. path: str,
  552. pattern: Optional[str] = None,
  553. max_size: Optional[int] = None,
  554. ) -> None:
  555. from safetensors.torch import save_file
  556. from aphrodite.distributed import get_tensor_model_parallel_rank
  557. if pattern is None:
  558. pattern = ShardedStateLoader.DEFAULT_PATTERN
  559. rank = get_tensor_model_parallel_rank()
  560. part_idx = 0
  561. total_size = 0
  562. state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
  563. state_dict_part: Dict[str, torch.Tensor] = {}
  564. for key, tensor in state_dict.items():
  565. param_size = tensor.nelement() * tensor.element_size()
  566. if max_size is not None and total_size + param_size > max_size:
  567. filename = pattern.format(rank=rank, part=part_idx)
  568. save_file(
  569. state_dict_part,
  570. os.path.join(path, filename),
  571. )
  572. part_idx += 1
  573. total_size = 0
  574. state_dict_part = {}
  575. state_dict_part[key] = tensor
  576. total_size += param_size
  577. if len(state_dict_part) > 0:
  578. filename = pattern.format(rank=rank, part=part_idx)
  579. save_file(
  580. state_dict_part,
  581. os.path.join(path, filename),
  582. )
  583. class BitsAndBytesModelLoader(BaseModelLoader):
  584. """Model loader to load model weights with BitAndBytes quantization."""
  585. # TODO: these module names are for Llama only,
  586. # change so that it works with other models as well
  587. default_target_modules = [
  588. "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
  589. "o_proj"
  590. ]
  591. possible_config_file_names = ["adapter_config.json"]
  592. def __init__(self, load_config: LoadConfig):
  593. super().__init__(load_config)
  594. # we don't need to quantize the whole model, only the target modules
  595. # that are specified in the adapter config file. If the adapter config
  596. # file is not provided, we will quantize the default modules.
  597. if (not load_config.model_loader_extra_config
  598. or "qlora_adapter_name_or_path"
  599. not in load_config.model_loader_extra_config):
  600. self.target_modules = self.default_target_modules
  601. return
  602. qlora_adapter = load_config.model_loader_extra_config[
  603. "qlora_adapter_name_or_path"]
  604. config_file_path = self._get_config_file(qlora_adapter)
  605. with open(config_file_path, "r") as f:
  606. config = json.load(f)
  607. self.target_modules = config["target_modules"]
  608. def _get_config_file(self, qlora_adapter: str) -> str:
  609. is_local = os.path.isdir(qlora_adapter)
  610. config_file_path = None
  611. if is_local:
  612. for file in self.possible_config_file_names:
  613. config_file_path = os.path.join(qlora_adapter, file)
  614. if os.path.exists(config_file_path):
  615. break
  616. else:
  617. hf_api = HfApi()
  618. repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
  619. for file in self.possible_config_file_names:
  620. if file in repo_files:
  621. config_file_path = hf_hub_download(repo_id=qlora_adapter,
  622. filename=file)
  623. break
  624. if not config_file_path:
  625. raise ValueError(
  626. f"Cannot find adapter config file in {qlora_adapter}")
  627. return config_file_path
  628. def _get_weight_files(
  629. self,
  630. model_name_or_path: str,
  631. allowed_patterns: List[str],
  632. revision: Optional[str] = None) -> Tuple[List[str], str]:
  633. """Retrieve weight files. Download the files if necessary.
  634. Return the weight files and the file pattern."""
  635. is_local = os.path.isdir(model_name_or_path)
  636. if is_local:
  637. for pattern in allowed_patterns:
  638. weight_files = glob.glob(
  639. os.path.join(model_name_or_path, pattern))
  640. if weight_files:
  641. return weight_files, pattern
  642. else:
  643. hf_api = HfApi()
  644. repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
  645. for pattern in allowed_patterns:
  646. matching_files = fnmatch.filter(repo_files, pattern)
  647. if matching_files:
  648. hf_folder = download_weights_from_hf(
  649. model_name_or_path,
  650. self.load_config.download_dir,
  651. [pattern],
  652. revision,
  653. ignore_patterns=self.load_config.ignore_patterns,
  654. )
  655. return glob.glob(os.path.join(hf_folder, pattern)), pattern
  656. raise RuntimeError(
  657. f"No model weights found in: `{model_name_or_path}`")
  658. def _prepare_weights(self, model_name_or_path: str,
  659. revision: Optional[str]) -> Tuple[List[str], bool]:
  660. """Prepare weight files for the model."""
  661. allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
  662. hf_weights_files, matched_pattern = self._get_weight_files(
  663. model_name_or_path, allowed_patterns, revision)
  664. if matched_pattern != "*.safetensors":
  665. hf_weights_files = filter_files_not_needed_for_inference(
  666. hf_weights_files)
  667. if len(hf_weights_files) == 0:
  668. raise RuntimeError(
  669. f"Cannot find any model weights with `{model_name_or_path}`")
  670. return hf_weights_files, matched_pattern == "*.safetensors"
  671. def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
  672. if use_safetensors:
  673. return safetensors_weights_iterator(hf_weights_files)
  674. else:
  675. return pt_weights_iterator(hf_weights_files)
  676. def _get_quantized_weights_iterator(
  677. self,
  678. model_name_or_path: str,
  679. revision: Optional[str],
  680. pre_quant: bool,
  681. load_8bit: bool,
  682. ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
  683. Any]]:
  684. """Get an iterator to the model weights with bitsandbytes quantization,
  685. as well as the quantization state dictionary."""
  686. # only load the bitsandbytes module when needed
  687. try:
  688. import bitsandbytes
  689. if bitsandbytes.__version__ < "0.42.0":
  690. raise ImportError("bitsandbytes version is wrong. Please "
  691. "install bitsandbytes>=0.42.0.")
  692. except ImportError as err:
  693. raise ImportError("Please install bitsandbytes>=0.42.0 via "
  694. "`pip install bitsandbytes>=0.42.0` to use "
  695. "bitsandbytes quantizer.") from err
  696. hf_weights_files, use_safetensors = self._prepare_weights(
  697. model_name_or_path, revision)
  698. quant_state_dict: Dict[str, Any] = {}
  699. if pre_quant:
  700. if load_8bit:
  701. return self._quantized_8bit_generator(
  702. hf_weights_files, use_safetensors,
  703. quant_state_dict), quant_state_dict
  704. else:
  705. return self._quantized_4bit_generator(
  706. hf_weights_files, use_safetensors,
  707. quant_state_dict), quant_state_dict
  708. return self._unquantized_generator(hf_weights_files, use_safetensors,
  709. quant_state_dict), quant_state_dict
  710. def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
  711. quant_state_dict) -> Generator:
  712. for weight_name, weight_tensor in self._hf_weight_iter(
  713. hf_weights_files, use_safetensors):
  714. if not weight_name.lower().endswith(".scb"):
  715. continue
  716. weight_key = weight_name.lower().replace(".scb", ".qweight")
  717. quant_state_dict[weight_key] = weight_tensor
  718. for weight_name, weight_tensor in self._hf_weight_iter(
  719. hf_weights_files, use_safetensors):
  720. if not weight_name.endswith(".weight"):
  721. continue
  722. qweight_name = weight_name.replace(".weight", ".qweight")
  723. if qweight_name in quant_state_dict:
  724. set_weight_attrs(weight_tensor, {"load_in_8bit": True})
  725. yield qweight_name, weight_tensor
  726. else:
  727. yield weight_name, weight_tensor
  728. def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
  729. quant_state_dict) -> Generator:
  730. from bitsandbytes.functional import QuantState
  731. # First iterate over all quant state weights
  732. weight_iterator = self._hf_weight_iter(hf_weights_files,
  733. use_safetensors)
  734. temp_state_dict = {}
  735. for weight_name, weight_tensor in weight_iterator:
  736. if weight_name.endswith(".weight"):
  737. continue
  738. # bitsandbytes library requires
  739. # weight.quant_state.bitsandbytes__* in CPU
  740. if "quant_state.bitsandbytes" in weight_name:
  741. temp_state_dict[weight_name] = weight_tensor.cpu().data
  742. else:
  743. temp_state_dict[weight_name] = weight_tensor
  744. # Closure to parse quant_state for each prequant weight
  745. def _parse_quant_state(param_name: str,
  746. temp_state_dict: Dict) -> QuantState:
  747. quant_state = {}
  748. for k in temp_state_dict:
  749. if param_name + "." in k:
  750. quant_state[k] = temp_state_dict[k]
  751. return QuantState.from_dict(quant_state, device="cuda")
  752. # Second iterate over all prequant and normal weights
  753. # pre quantized weights would have a quant_state
  754. for weight_name, weight_tensor in self._hf_weight_iter(
  755. hf_weights_files, use_safetensors):
  756. # Filter out all weights whose suffix is not ".weight"
  757. if not weight_name.endswith(".weight"):
  758. continue
  759. if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
  760. in temp_state_dict) or \
  761. (f"{weight_name}.quant_state.bitsandbytes__fp4" \
  762. in temp_state_dict):
  763. quant_state = _parse_quant_state(weight_name, temp_state_dict)
  764. weight_name = weight_name.replace(".weight", ".qweight")
  765. quant_state_dict[weight_name] = quant_state
  766. yield weight_name.replace(".weight", ".qweight"), weight_tensor
  767. else:
  768. yield weight_name, weight_tensor
  769. def _unquantized_generator(self, hf_weights_files, use_safetensors,
  770. quant_state_dict) -> Generator:
  771. from bitsandbytes.functional import quantize_4bit
  772. tp_size = get_tensor_model_parallel_world_size()
  773. tp_rank = get_tensor_model_parallel_rank()
  774. for weight_name, weight_tensor in self._hf_weight_iter(
  775. hf_weights_files, use_safetensors):
  776. if any(target_module in weight_name
  777. for target_module in self.target_modules):
  778. weight_name = weight_name.replace(".weight", ".qweight")
  779. # weight partitions of different modules occur at
  780. # different dimensions
  781. # TODO: these module names are for Llama only,
  782. # change so that it works with other models as well
  783. if 'down_proj' in weight_name or 'o_proj' in weight_name:
  784. total_size = weight_tensor.size(-1)
  785. start_index = total_size // tp_size * tp_rank
  786. end_index = total_size // tp_size * (tp_rank + 1)
  787. weight_sub_tensor = weight_tensor[...,
  788. start_index:end_index]
  789. else:
  790. total_size = weight_tensor.size(0)
  791. start_index = total_size // tp_size * tp_rank
  792. end_index = total_size // tp_size * (tp_rank + 1)
  793. weight_sub_tensor = weight_tensor[start_index:end_index,
  794. ...]
  795. # bitsandbytes requires data in GPU
  796. if weight_sub_tensor.is_cuda:
  797. loaded_weight = weight_sub_tensor
  798. else:
  799. loaded_weight = weight_sub_tensor.cuda()
  800. # remove the following after the issue is fixed:
  801. # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
  802. if loaded_weight.is_contiguous() is False:
  803. loaded_weight = loaded_weight.contiguous()
  804. with set_default_torch_dtype(torch.float32):
  805. processed_weight, quant_state = quantize_4bit(
  806. loaded_weight,
  807. compress_statistics=True,
  808. quant_type="nf4")
  809. quant_state_dict[weight_name] = quant_state
  810. else:
  811. processed_weight = weight_tensor
  812. yield weight_name, processed_weight
  813. def _load_weights(self, model_config: ModelConfig,
  814. model: nn.Module) -> None:
  815. if not hasattr(model, 'load_weights'):
  816. raise AttributeError(
  817. "The required method 'load_weights' is not defined in class"
  818. f" {type(self).__name__}.")
  819. if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
  820. raise AttributeError(
  821. f"Model {type(self).__name__} does not support BitsAndBytes "
  822. "quantization yet.")
  823. logger.info("Loading weights with BitsAndBytes quantization. "
  824. "This May take a while ...")
  825. quant_config = getattr(model_config.hf_config, "quantization_config",
  826. None)
  827. pre_quant = False
  828. if quant_config is not None:
  829. quant_method = quant_config.get('quant_method')
  830. if quant_method == "bitsandbytes":
  831. pre_quant = True
  832. else:
  833. raise ValueError(
  834. f"BitsAndBytes loader does not support {quant_method} "
  835. "quantization")
  836. # The quant_states in pre_quantized models cannot work with a split
  837. # weight tensor. So TP does not work with pre_quantized bnb models.
  838. if pre_quant and get_tensor_model_parallel_world_size() > 1:
  839. raise ValueError(
  840. "Prequant BitsAndBytes models with TP is not supported."
  841. "Please try with PP.")
  842. load_8bit = False
  843. if pre_quant:
  844. load_8bit = quant_config.get('load_in_8bit', False)
  845. qweight_iterator, quant_state_dict = \
  846. self._get_quantized_weights_iterator(
  847. model_config.model, model_config.revision, pre_quant, load_8bit)
  848. model.load_weights(qweight_iterator)
  849. torch.cuda.empty_cache()
  850. param_dict = dict(model.named_parameters())
  851. stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
  852. for quant_param_name in quant_state_dict:
  853. non_stacked_param_name = quant_param_name
  854. shard_index = 0
  855. for shard_name, (
  856. weight_name, index
  857. ) in model.bitsandbytes_stacked_params_mapping.items():
  858. if shard_name in quant_param_name:
  859. shard_index = index
  860. quant_param_name = quant_param_name.replace(
  861. shard_name, weight_name)
  862. break
  863. if quant_param_name not in param_dict:
  864. raise ValueError(
  865. f"Parameter {quant_param_name} not found in the model.")
  866. if quant_param_name not in stacked_quant_state_dict:
  867. stacked_quant_state_dict[quant_param_name] = {}
  868. stacked_quant_state_dict[quant_param_name][shard_index] = (
  869. quant_state_dict[non_stacked_param_name])
  870. # save quant_states and offsets as the attributes of the parameters
  871. for param_name, param in param_dict.items():
  872. if param_name in stacked_quant_state_dict:
  873. quant_states = stacked_quant_state_dict[param_name]
  874. set_weight_attrs(param, {"bnb_quant_state": quant_states})
  875. pack_ratio = getattr(param, "pack_factor", -1)
  876. if pack_ratio == -1:
  877. raise ValueError(
  878. f"pack_factor not set for parameter {param_name}.")
  879. num_elements = [0] * len(quant_states)
  880. for seq, quant_state in quant_states.items():
  881. num_elements[seq] = math.prod(
  882. quant_state.shape) // pack_ratio
  883. offsets = np.concatenate(([0], np.cumsum(num_elements)))
  884. set_weight_attrs(param, {"bnb_shard_offsets": offsets})
  885. if load_8bit:
  886. set_weight_attrs(
  887. param, {"matmul_state": [None] * len(quant_states)})
  888. def load_model(self, *, model_config: ModelConfig,
  889. device_config: DeviceConfig,
  890. lora_config: Optional[LoRAConfig],
  891. parallel_config: ParallelConfig,
  892. scheduler_config: SchedulerConfig,
  893. cache_config: CacheConfig) -> nn.Module:
  894. with set_default_torch_dtype(model_config.dtype):
  895. with torch.device(device_config.device):
  896. model = _initialize_model(model_config, self.load_config,
  897. lora_config, cache_config)
  898. self._load_weights(model_config, model)
  899. return model.eval()
  900. class GGUFModelLoader(BaseModelLoader):
  901. """
  902. Model loader that can load GGUF files. This is useful for loading models
  903. that are quantized with GGUF and saved in the GGUF format. This loader
  904. supports loading both full models and sharded models.
  905. """
  906. def __init__(self, load_config: LoadConfig):
  907. super().__init__(load_config)
  908. if load_config.model_loader_extra_config:
  909. raise ValueError(f"Model loader extra config is not supported for "
  910. f"load format {load_config.load_format}")
  911. def _prepare_weights(self, model_name_or_path: str):
  912. if os.path.isfile(model_name_or_path):
  913. return model_name_or_path
  914. else:
  915. raise ValueError(f"{model_name_or_path} is not a file.")
  916. def _get_gguf_weights_map(self, model_config: ModelConfig):
  917. """
  918. GGUF uses this naming convention for their tensors from HF checkpoint:
  919. `blk.N.BB.weight` and `blk.N.BB.bias`
  920. where N signifies the block number of a layer, and BB signifies the
  921. attention/mlp layer components.
  922. See "Standardized tensor names" in
  923. https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
  924. """
  925. config = model_config.hf_config
  926. model_type = config.model_type
  927. # hack: ggufs have a different name than transformers
  928. if model_type == "cohere":
  929. model_type = "command-r"
  930. arch = None
  931. for key, value in gguf.MODEL_ARCH_NAMES.items():
  932. if value == model_type:
  933. arch = key
  934. break
  935. if arch is None:
  936. raise RuntimeError(f"Unknown gguf model_type: {model_type}")
  937. num_layers = config.num_hidden_layers
  938. name_map = gguf.get_tensor_name_map(arch, num_layers)
  939. with torch.device("meta"):
  940. dummy_model = AutoModelForCausalLM.from_config(config)
  941. state_dict = dummy_model.state_dict()
  942. gguf_to_hf_name_map = {}
  943. for hf_name in state_dict:
  944. name, suffix = hf_name.rsplit(".", 1)
  945. gguf_name = name_map.get_name(name)
  946. gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
  947. return gguf_to_hf_name_map
  948. def _get_weights_iterator(
  949. self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
  950. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  951. return gguf_quant_weights_iterator(model_name_or_path,
  952. gguf_to_hf_name_map)
  953. def load_model(self, *, model_config: ModelConfig,
  954. device_config: DeviceConfig,
  955. lora_config: Optional[LoRAConfig],
  956. parallel_config: ParallelConfig,
  957. scheduler_config: SchedulerConfig,
  958. cache_config: CacheConfig) -> nn.Module:
  959. local_model_path = self._prepare_weights(model_config.model)
  960. gguf_weights_map = self._get_gguf_weights_map(model_config)
  961. # we can only know if tie word embeddings after mapping weights
  962. if "lm_head.weight" in get_gguf_extra_tensor_names(
  963. local_model_path, gguf_weights_map):
  964. model_config.hf_config.update({"tie_word_embeddings": True})
  965. with set_default_torch_dtype(model_config.dtype):
  966. with torch.device(device_config.device):
  967. model = _initialize_model(model_config, self.load_config,
  968. lora_config, cache_config)
  969. model.load_weights(
  970. self._get_weights_iterator(local_model_path, gguf_weights_map))
  971. return model
  972. def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  973. """Get a model loader based on the load format."""
  974. if isinstance(load_config.load_format, type):
  975. return load_config.load_format(load_config)
  976. if load_config.load_format == LoadFormat.DUMMY:
  977. return DummyModelLoader(load_config)
  978. if load_config.load_format == LoadFormat.TENSORIZER:
  979. return TensorizerLoader(load_config)
  980. if load_config.load_format == LoadFormat.SHARDED_STATE:
  981. return ShardedStateLoader(load_config)
  982. if load_config.load_format == LoadFormat.BITSANDBYTES:
  983. return BitsAndBytesModelLoader(load_config)
  984. if load_config.load_format == LoadFormat.GGUF:
  985. return GGUFModelLoader(load_config)
  986. return DefaultModelLoader(load_config)