1
0

loader.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069
  1. # ruff: noqa: SIM117
  2. import collections
  3. import copy
  4. import fnmatch
  5. import glob
  6. import json
  7. import math
  8. import os
  9. from abc import ABC, abstractmethod
  10. from contextlib import contextmanager
  11. from typing import Any, Dict, Generator, List, Optional, Tuple, Type
  12. import gguf
  13. import huggingface_hub
  14. import numpy as np
  15. import torch
  16. from huggingface_hub import HfApi, hf_hub_download
  17. from loguru import logger
  18. from torch import nn
  19. from transformers import AutoModelForCausalLM, PretrainedConfig
  20. from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
  21. from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
  22. DeviceConfig, LoadConfig, LoadFormat,
  23. LoRAConfig, ModelConfig, MultiModalConfig,
  24. ParallelConfig, SchedulerConfig)
  25. from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
  26. from aphrodite.modeling.model_loader.tensorizer import (
  27. TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
  28. serialize_aphrodite_model, tensorizer_weights_iterator)
  29. from aphrodite.modeling.model_loader.utils import (get_model_architecture,
  30. set_default_torch_dtype)
  31. from aphrodite.modeling.model_loader.weight_utils import (
  32. download_safetensors_index_file_from_hf, download_weights_from_hf,
  33. filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
  34. get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
  35. initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
  36. safetensors_weights_iterator)
  37. from aphrodite.modeling.models.interfaces import (has_inner_state,
  38. supports_lora,
  39. supports_multimodal)
  40. from aphrodite.modeling.utils import set_weight_attrs
  41. from aphrodite.platforms import current_platform
  42. from aphrodite.quantization.base_config import QuantizationConfig
  43. @contextmanager
  44. def device_loading_context(module: torch.nn.Module,
  45. target_device: torch.device):
  46. if target_device.type == "cpu":
  47. # If target is CPU, no need to move anything
  48. yield module
  49. return
  50. original_device_states: Dict[str, torch.device] = {}
  51. # Store original device states and move parameters to GPU if they're on CPU
  52. for name, p in module.named_parameters():
  53. if p.device.type == "cpu":
  54. original_device_states[name] = p.device
  55. p.data = p.data.to(target_device)
  56. # Parameters already on target device are not touched
  57. try:
  58. yield module
  59. finally:
  60. # Restore parameters to their original devices, ignoring new parameters
  61. pin_memory = is_pin_memory_available()
  62. for name, p in module.named_parameters():
  63. if name in original_device_states:
  64. original_device: torch.device = original_device_states[name]
  65. if original_device.type == "cpu":
  66. # `torch.empty_like` does not support `pin_memory` argument
  67. cpu_data = torch.empty_strided(size=p.data.size(),
  68. stride=p.data.stride(),
  69. dtype=p.data.dtype,
  70. layout=p.data.layout,
  71. device="cpu",
  72. pin_memory=pin_memory)
  73. cpu_data.copy_(p.data)
  74. p.data = cpu_data
  75. else:
  76. p.data = p.data.to(original_device)
  77. # New parameters or parameters already on target device are untouched
  78. def _get_quantization_config(
  79. model_config: ModelConfig,
  80. load_config: LoadConfig) -> Optional[QuantizationConfig]:
  81. """Get the quantization config."""
  82. if model_config.quantization is not None:
  83. quant_config = get_quant_config(model_config, load_config)
  84. if not current_platform.is_tpu():
  85. capability = current_platform.get_device_capability()
  86. capability = capability[0] * 10 + capability[1]
  87. if capability < quant_config.get_min_capability():
  88. raise ValueError(
  89. f"The quantization method {model_config.quantization} "
  90. "is not supported for the current GPU. "
  91. f"Minimum capability: {quant_config.get_min_capability()}. "
  92. f"Current capability: {capability}.")
  93. supported_dtypes = quant_config.get_supported_act_dtypes()
  94. if model_config.dtype not in supported_dtypes:
  95. raise ValueError(
  96. f"{model_config.dtype} is not supported for quantization "
  97. f"method {model_config.quantization}. Supported dtypes: "
  98. f"{supported_dtypes}")
  99. return quant_config
  100. return None
  101. def _get_model_initialization_kwargs(
  102. model_class: Type[nn.Module],
  103. lora_config: Optional[LoRAConfig],
  104. multimodal_config: Optional[MultiModalConfig],
  105. scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
  106. """Get extra kwargs for model initialization."""
  107. extra_kwargs: Dict[str, Any] = {}
  108. if supports_lora(model_class):
  109. # lora_config=None is used to disable LoRA
  110. extra_kwargs["lora_config"] = lora_config
  111. elif lora_config:
  112. raise ValueError(
  113. f"Model {model_class.__name__} does not support LoRA, "
  114. "but LoRA is enabled. Support for this model may "
  115. "be added in the future. If this is important to you, "
  116. "please open an issue on github.")
  117. if supports_multimodal(model_class):
  118. assert multimodal_config is not None
  119. extra_kwargs["multimodal_config"] = multimodal_config
  120. if has_inner_state(model_class) and scheduler_config:
  121. extra_kwargs["scheduler_config"] = scheduler_config
  122. return extra_kwargs
  123. def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
  124. cache_config: Optional[CacheConfig],
  125. quant_config: Optional[QuantizationConfig], *,
  126. lora_config: Optional[LoRAConfig],
  127. multimodal_config: Optional[MultiModalConfig],
  128. scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
  129. extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
  130. multimodal_config,
  131. scheduler_config)
  132. return model_class(config=hf_config,
  133. cache_config=cache_config,
  134. quant_config=quant_config,
  135. **extra_kwargs)
  136. def _initialize_model(
  137. model_config: ModelConfig,
  138. load_config: LoadConfig,
  139. lora_config: Optional[LoRAConfig],
  140. cache_config: CacheConfig,
  141. scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
  142. """Initialize a model with the given configurations."""
  143. model_class, _ = get_model_architecture(model_config)
  144. return build_model(
  145. model_class,
  146. model_config.hf_config,
  147. cache_config=cache_config,
  148. quant_config=_get_quantization_config(model_config, load_config),
  149. lora_config=lora_config,
  150. multimodal_config=model_config.multimodal_config,
  151. scheduler_config=scheduler_config,
  152. )
  153. class BaseModelLoader(ABC):
  154. """Base class for model loaders."""
  155. def __init__(self, load_config: LoadConfig):
  156. self.load_config = load_config
  157. @abstractmethod
  158. def load_model(self, *, model_config: ModelConfig,
  159. device_config: DeviceConfig,
  160. lora_config: Optional[LoRAConfig],
  161. parallel_config: ParallelConfig,
  162. scheduler_config: SchedulerConfig,
  163. cache_config: CacheConfig) -> nn.Module:
  164. """Load a model with the given configurations."""
  165. ...
  166. class DefaultModelLoader(BaseModelLoader):
  167. """Model loader that can load different file types from disk."""
  168. def __init__(self, load_config: LoadConfig):
  169. super().__init__(load_config)
  170. if load_config.model_loader_extra_config:
  171. raise ValueError(f"Model loader extra config is not supported for "
  172. f"load format {load_config.load_format}")
  173. def _maybe_download_from_modelscope(
  174. self, model: str, revision: Optional[str]) -> Optional[str]:
  175. """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
  176. True.
  177. Returns the path to the downloaded model, or None if the model is not
  178. downloaded from ModelScope."""
  179. if APHRODITE_USE_MODELSCOPE:
  180. # download model from ModelScope hub,
  181. # lazy import so that modelscope is not required for normal use.
  182. # pylint: disable=C.
  183. from modelscope.hub.snapshot_download import snapshot_download
  184. if not os.path.exists(model):
  185. model_path = snapshot_download(
  186. model_id=model,
  187. cache_dir=self.load_config.download_dir,
  188. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  189. revision=revision,
  190. ignore_file_pattern=self.load_config.ignore_patterns,
  191. )
  192. else:
  193. model_path = model
  194. return model_path
  195. return None
  196. def _prepare_weights(self, model_name_or_path: str,
  197. revision: Optional[str],
  198. fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
  199. """Prepare weights for the model.
  200. If the model is not local, it will be downloaded."""
  201. model_name_or_path = self._maybe_download_from_modelscope(
  202. model_name_or_path, revision) or model_name_or_path
  203. is_local = os.path.isdir(model_name_or_path)
  204. load_format = self.load_config.load_format
  205. use_safetensors = False
  206. index_file = SAFE_WEIGHTS_INDEX_NAME
  207. # Some quantized models use .pt files for storing the weights.
  208. if load_format == LoadFormat.AUTO:
  209. allow_patterns = ["*.safetensors", "*.bin"]
  210. elif load_format == LoadFormat.SAFETENSORS:
  211. use_safetensors = True
  212. allow_patterns = ["*.safetensors"]
  213. elif load_format == LoadFormat.MISTRAL:
  214. use_safetensors = True
  215. allow_patterns = ["consolidated*.safetensors"]
  216. index_file = "consolidated.safetensors.index.json"
  217. elif load_format == LoadFormat.PT:
  218. allow_patterns = ["*.pt"]
  219. elif load_format == LoadFormat.NPCACHE:
  220. allow_patterns = ["*.bin"]
  221. else:
  222. raise ValueError(f"Unknown load_format: {load_format}")
  223. if fall_back_to_pt:
  224. allow_patterns += ["*.pt"]
  225. if not is_local:
  226. hf_folder = download_weights_from_hf(
  227. model_name_or_path,
  228. self.load_config.download_dir,
  229. allow_patterns,
  230. revision,
  231. ignore_patterns=self.load_config.ignore_patterns,
  232. )
  233. else:
  234. hf_folder = model_name_or_path
  235. hf_weights_files: List[str] = []
  236. for pattern in allow_patterns:
  237. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  238. if len(hf_weights_files) > 0:
  239. if pattern == "*.safetensors":
  240. use_safetensors = True
  241. break
  242. if use_safetensors:
  243. # For models like Mistral-7B-Instruct-v0.3
  244. # there are both sharded safetensors files and a consolidated
  245. # safetensors file. Using both breaks.
  246. # Here, we download the `model.safetensors.index.json` and filter
  247. # any files not found in the index.
  248. if not is_local:
  249. download_safetensors_index_file_from_hf(
  250. model_name_or_path, index_file,
  251. self.load_config.download_dir, revision)
  252. hf_weights_files = filter_duplicate_safetensors_files(
  253. hf_weights_files, hf_folder, index_file)
  254. else:
  255. hf_weights_files = filter_files_not_needed_for_inference(
  256. hf_weights_files)
  257. if len(hf_weights_files) == 0:
  258. raise RuntimeError(
  259. f"Cannot find any model weights with `{model_name_or_path}`")
  260. return hf_folder, hf_weights_files, use_safetensors
  261. def _get_weights_iterator(
  262. self, model_name_or_path: str, revision: Optional[str],
  263. fall_back_to_pt: bool
  264. ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], int]:
  265. """Get an iterator for the model weights based on the load format."""
  266. hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
  267. model_name_or_path, revision, fall_back_to_pt)
  268. est_weight_bytes = sum(os.path.getsize(f)
  269. for f in hf_weights_files)
  270. if self.load_config.load_format == LoadFormat.NPCACHE:
  271. # Currently np_cache only support *.bin checkpoints
  272. assert use_safetensors is False
  273. weights_iterator = np_cache_weights_iterator(
  274. model_name_or_path, self.load_config.download_dir, hf_folder,
  275. hf_weights_files)
  276. elif use_safetensors:
  277. weights_iterator = safetensors_weights_iterator(hf_weights_files)
  278. else:
  279. weights_iterator = pt_weights_iterator(hf_weights_files)
  280. if current_platform.is_tpu():
  281. # In PyTorch XLA, we should call `xm.mark_step` frequently so that
  282. # not too many ops are accumulated in the XLA program.
  283. import torch_xla.core.xla_model as xm
  284. def _xla_weights_iterator(iterator: Generator):
  285. for weights in iterator:
  286. yield weights
  287. xm.mark_step()
  288. weights_iterator = _xla_weights_iterator(weights_iterator)
  289. return weights_iterator, est_weight_bytes
  290. def load_model(self, *, model_config: ModelConfig,
  291. device_config: DeviceConfig,
  292. lora_config: Optional[LoRAConfig],
  293. parallel_config: ParallelConfig,
  294. scheduler_config: SchedulerConfig,
  295. cache_config: CacheConfig) -> nn.Module:
  296. target_device = torch.device(device_config.device)
  297. with set_default_torch_dtype(model_config.dtype):
  298. with target_device:
  299. model = _initialize_model(model_config, self.load_config,
  300. lora_config, cache_config,
  301. scheduler_config)
  302. weights, wgt_bytes = self._get_weights_iterator(model_config.model,
  303. model_config.revision,
  304. fall_back_to_pt=getattr(
  305. model,
  306. "fall_back_to_pt_during_load",
  307. True))
  308. model.load_weights(tensor_progress_bar(weights, wgt_bytes,
  309. "Loading model weights..."))
  310. for _, module in model.named_modules():
  311. quant_method = getattr(module, "quant_method", None)
  312. if quant_method is not None:
  313. # When quant methods need to process weights after loading
  314. # (for repacking, quantizing, etc), they expect parameters
  315. # to be on the global target device. This scope is for the
  316. # case where cpu offloading is used, where we will move the
  317. # parameters onto device for processing and back off after.
  318. with device_loading_context(module, target_device):
  319. quant_method.process_weights_after_loading(module)
  320. return model.eval()
  321. class DummyModelLoader(BaseModelLoader):
  322. """Model loader that will set model weights to random values."""
  323. def __init__(self, load_config: LoadConfig):
  324. super().__init__(load_config)
  325. if load_config.model_loader_extra_config:
  326. raise ValueError(f"Model loader extra config is not supported for "
  327. f"load format {load_config.load_format}")
  328. def load_model(self, *, model_config: ModelConfig,
  329. device_config: DeviceConfig,
  330. lora_config: Optional[LoRAConfig],
  331. parallel_config: ParallelConfig,
  332. scheduler_config: SchedulerConfig,
  333. cache_config: CacheConfig) -> nn.Module:
  334. with set_default_torch_dtype(model_config.dtype):
  335. with torch.device(device_config.device):
  336. model = _initialize_model(model_config, self.load_config,
  337. lora_config, cache_config,
  338. scheduler_config)
  339. # NOTE: For accurate performance evaluation, we assign
  340. # random values to the weights.
  341. initialize_dummy_weights(model)
  342. return model.eval()
  343. class TensorizerLoader(BaseModelLoader):
  344. """Model loader using CoreWeave's tensorizer library."""
  345. def __init__(self, load_config: LoadConfig):
  346. super().__init__(load_config)
  347. if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
  348. self.tensorizer_config = load_config.model_loader_extra_config
  349. else:
  350. self.tensorizer_config = TensorizerConfig(
  351. **load_config.model_loader_extra_config)
  352. def _verify_config(self, model_config: ModelConfig,
  353. parallel_config: ParallelConfig):
  354. self.tensorizer_config.verify_with_model_config(model_config)
  355. self.tensorizer_config.verify_with_parallel_config(parallel_config)
  356. def _get_weights_iterator(
  357. self) -> Generator[Tuple[str, torch.Tensor], None, None]:
  358. tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
  359. return tensorizer_weights_iterator(tensorizer_args)
  360. def _load_model_serialized_cpu(
  361. self,
  362. model_config: ModelConfig,
  363. device_config: DeviceConfig,
  364. lora_config: Optional[LoRAConfig],
  365. cache_config: CacheConfig,
  366. ) -> nn.Module:
  367. """Load a serialized model with tensorizer to the CPU.
  368. This is only necessary when the model isn't Aphrodite-tensorized (see
  369. examples/tensorize_aphrodite_model.py) This should still be faster than
  370. default HuggingFace loading, but will be slower than loading a
  371. Aphrodite-tensorized model.
  372. """
  373. with set_default_torch_dtype(model_config.dtype):
  374. with torch.device(device_config.device):
  375. model = _initialize_model(model_config, self.load_config,
  376. lora_config, cache_config)
  377. model.load_weights(self._get_weights_iterator())
  378. return model.eval()
  379. def _load_model_serialized(
  380. self,
  381. model_config: ModelConfig,
  382. device_config: DeviceConfig,
  383. lora_config: Optional[LoRAConfig],
  384. cache_config: CacheConfig,
  385. ) -> nn.Module:
  386. """Load a serialized model with tensorizer.
  387. Expects a Aphrodite-tensorized model. See the
  388. examples/tensorize_aphrodite_model.py example script
  389. for serializing Aphrodite models."""
  390. with set_default_torch_dtype(model_config.dtype):
  391. with torch.device(device_config.device):
  392. model_class = get_model_architecture(model_config)[0]
  393. quant_config = _get_quantization_config(
  394. model_config, self.load_config)
  395. extra_kwargs = _get_model_initialization_kwargs(
  396. model_class, lora_config, model_config.multimodal_config)
  397. extra_kwargs["quant_config"] = quant_config
  398. extra_kwargs["cache_config"] = cache_config
  399. tensorizer_config = copy.copy(self.tensorizer_config)
  400. tensorizer_config.model_class = model_class
  401. tensorizer_config.hf_config = model_config.hf_config
  402. tensorizer_config.dtype = model_config.dtype
  403. model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
  404. return model.eval()
  405. def load_model(self, *, model_config: ModelConfig,
  406. device_config: DeviceConfig,
  407. lora_config: Optional[LoRAConfig],
  408. parallel_config: ParallelConfig,
  409. scheduler_config: SchedulerConfig,
  410. cache_config: CacheConfig) -> nn.Module:
  411. self._verify_config(model_config, parallel_config)
  412. if parallel_config.tensor_parallel_size > 1:
  413. from aphrodite.distributed import get_tensor_model_parallel_rank
  414. self.tensorizer_config.tensorizer_uri = \
  415. self.tensorizer_config.tensorizer_uri \
  416. % get_tensor_model_parallel_rank()
  417. if is_aphrodite_tensorized(self.tensorizer_config):
  418. return self._load_model_serialized(model_config, device_config,
  419. lora_config, cache_config)
  420. return self._load_model_serialized_cpu(model_config, device_config,
  421. lora_config, cache_config)
  422. @staticmethod
  423. def save_model(
  424. model: torch.nn.Module,
  425. tensorizer_config: TensorizerConfig,
  426. ) -> None:
  427. serialize_aphrodite_model(
  428. model=model,
  429. tensorizer_config=tensorizer_config,
  430. )
  431. class ShardedStateLoader(BaseModelLoader):
  432. """
  433. Model loader that directly loads each worker's model state dict, which
  434. enables a fast load path for large tensor-parallel models where each worker
  435. only needs to read its own shard rather than the entire checkpoint. See
  436. `examples/save_sharded_state.py` for creating a sharded checkpoint.
  437. """
  438. DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
  439. def __init__(self, load_config: LoadConfig):
  440. super().__init__(load_config)
  441. extra_config = ({} if load_config.model_loader_extra_config is None
  442. else load_config.model_loader_extra_config.copy())
  443. self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
  444. if extra_config:
  445. raise ValueError(f"Unexpected extra config keys for load format "
  446. f"{load_config.load_format}: "
  447. f"{load_config.model_loader_extra_config.keys()}")
  448. @staticmethod
  449. def _filter_subtensors(
  450. tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
  451. """
  452. Filter out all tensors that share the same memory or a subset of the
  453. memory of another tensor.
  454. """
  455. same_storage_groups: Dict[Any, List[Tuple[
  456. str, torch.Tensor]]] = collections.defaultdict(list)
  457. for key, tensor in tensors.items():
  458. if tensor.numel():
  459. ptr = tensor.untyped_storage().data_ptr()
  460. same_storage_groups[tensor.device, ptr].append((key, tensor))
  461. def get_end_ptr(tensor: torch.Tensor) -> int:
  462. return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
  463. result: Dict[str, torch.Tensor] = {}
  464. for group in same_storage_groups.values():
  465. for k, t in group:
  466. a, b = t.data_ptr(), get_end_ptr(t)
  467. for k2, t2 in group:
  468. if not t2.is_contiguous():
  469. continue
  470. a2, b2 = t2.data_ptr(), get_end_ptr(t2)
  471. if a < a2 or b2 < b:
  472. continue
  473. if a2 < a or b < b2 or not t.is_contiguous():
  474. break # t2 covers strictly more memory than t.
  475. if k2 < k:
  476. # Same tensors, keep the one with the smaller key.
  477. break
  478. else:
  479. result[k] = t
  480. return result
  481. def _prepare_weights(self, model_name_or_path: str,
  482. revision: Optional[str]):
  483. if os.path.isdir(model_name_or_path):
  484. return model_name_or_path
  485. else:
  486. allow_patterns = ["*.safetensors"]
  487. return download_weights_from_hf(
  488. model_name_or_path,
  489. self.load_config.download_dir,
  490. allow_patterns,
  491. revision,
  492. ignore_patterns=self.load_config.ignore_patterns,
  493. )
  494. def load_model(self, *, model_config: ModelConfig,
  495. device_config: DeviceConfig,
  496. lora_config: Optional[LoRAConfig],
  497. parallel_config: ParallelConfig,
  498. scheduler_config: SchedulerConfig,
  499. cache_config: CacheConfig) -> nn.Module:
  500. from safetensors.torch import safe_open
  501. from aphrodite.distributed import get_tensor_model_parallel_rank
  502. local_model_path = self._prepare_weights(model_config.model,
  503. model_config.revision)
  504. with set_default_torch_dtype(model_config.dtype):
  505. with torch.device(device_config.device):
  506. model = _initialize_model(model_config, self.load_config,
  507. lora_config, cache_config)
  508. for _, module in model.named_modules():
  509. quant_method = getattr(module, "quant_method", None)
  510. if quant_method is not None:
  511. quant_method.process_weights_after_loading(module)
  512. rank = get_tensor_model_parallel_rank()
  513. pattern = os.path.join(
  514. local_model_path,
  515. self.pattern.format(rank=rank, part="*"),
  516. )
  517. filepaths = glob.glob(pattern)
  518. if not filepaths:
  519. # TODO: support un-sharded checkpoints too
  520. raise ValueError(
  521. f"Could not find checkpoint files '{pattern}', only "
  522. f"pre-sharded checkpoints are currently supported!")
  523. state_dict = self._filter_subtensors(model.state_dict())
  524. for path in filepaths:
  525. with safe_open(path, framework="pt") as f:
  526. for key in f.keys(): # noqa: SIM118
  527. tensor = f.get_tensor(key)
  528. # If loading with LoRA enabled, additional padding may
  529. # be added to certain parameters. We only load into a
  530. # narrowed view of the parameter data.
  531. param_data = state_dict[key].data
  532. param_shape = state_dict[key].shape
  533. for dim, size in enumerate(tensor.shape):
  534. if size < param_shape[dim]:
  535. param_data = param_data.narrow(dim, 0, size)
  536. if tensor.shape != param_shape:
  537. logger.warning("loading tensor of shape "
  538. f"{tensor.shape} into parameter "
  539. f"'{key}' of shape {param_shape}")
  540. param_data.copy_(tensor)
  541. state_dict.pop(key)
  542. if state_dict:
  543. raise ValueError(
  544. f"Missing keys {tuple(state_dict)} in loaded state!")
  545. return model.eval()
  546. @staticmethod
  547. def save_model(
  548. model: torch.nn.Module,
  549. path: str,
  550. pattern: Optional[str] = None,
  551. max_size: Optional[int] = None,
  552. ) -> None:
  553. from safetensors.torch import save_file
  554. from aphrodite.distributed import get_tensor_model_parallel_rank
  555. if pattern is None:
  556. pattern = ShardedStateLoader.DEFAULT_PATTERN
  557. rank = get_tensor_model_parallel_rank()
  558. part_idx = 0
  559. total_size = 0
  560. state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
  561. state_dict_part: Dict[str, torch.Tensor] = {}
  562. for key, tensor in state_dict.items():
  563. param_size = tensor.nelement() * tensor.element_size()
  564. if max_size is not None and total_size + param_size > max_size:
  565. filename = pattern.format(rank=rank, part=part_idx)
  566. save_file(
  567. state_dict_part,
  568. os.path.join(path, filename),
  569. )
  570. part_idx += 1
  571. total_size = 0
  572. state_dict_part = {}
  573. state_dict_part[key] = tensor
  574. total_size += param_size
  575. if len(state_dict_part) > 0:
  576. filename = pattern.format(rank=rank, part=part_idx)
  577. save_file(
  578. state_dict_part,
  579. os.path.join(path, filename),
  580. )
  581. class BitsAndBytesModelLoader(BaseModelLoader):
  582. """Model loader to load model weights with BitAndBytes quantization."""
  583. default_target_modules = [
  584. "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
  585. "o_proj"
  586. ]
  587. possible_config_file_names = ["adapter_config.json"]
  588. def __init__(self, load_config: LoadConfig):
  589. super().__init__(load_config)
  590. # we don't need to quantize the whole model, only the target modules
  591. # that are specified in the adapter config file. If the adapter config
  592. # file is not provided, we will quantize the default modules.
  593. if (not load_config.model_loader_extra_config
  594. or "qlora_adapter_name_or_path"
  595. not in load_config.model_loader_extra_config):
  596. self.target_modules = self.default_target_modules
  597. return
  598. qlora_adapter = load_config.model_loader_extra_config[
  599. "qlora_adapter_name_or_path"]
  600. config_file_path = self._get_config_file(qlora_adapter)
  601. with open(config_file_path, "r") as f:
  602. config = json.load(f)
  603. self.target_modules = config["target_modules"]
  604. def _get_config_file(self, qlora_adapter: str) -> str:
  605. is_local = os.path.isdir(qlora_adapter)
  606. config_file_path = None
  607. if is_local:
  608. for file in self.possible_config_file_names:
  609. config_file_path = os.path.join(qlora_adapter, file)
  610. if os.path.exists(config_file_path):
  611. break
  612. else:
  613. hf_api = HfApi()
  614. repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
  615. for file in self.possible_config_file_names:
  616. if file in repo_files:
  617. config_file_path = hf_hub_download(repo_id=qlora_adapter,
  618. filename=file)
  619. break
  620. if not config_file_path:
  621. raise ValueError(
  622. f"Cannot find adapter config file in {qlora_adapter}")
  623. return config_file_path
  624. def _get_weight_files(
  625. self,
  626. model_name_or_path: str,
  627. allowed_patterns: List[str],
  628. revision: Optional[str] = None) -> Tuple[List[str], str]:
  629. """Retrieve weight files. Download the files if necessary.
  630. Return the weight files and the file pattern."""
  631. is_local = os.path.isdir(model_name_or_path)
  632. if is_local:
  633. for pattern in allowed_patterns:
  634. weight_files = glob.glob(
  635. os.path.join(model_name_or_path, pattern))
  636. if weight_files:
  637. return weight_files, pattern
  638. else:
  639. hf_api = HfApi()
  640. repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
  641. for pattern in allowed_patterns:
  642. matching_files = fnmatch.filter(repo_files, pattern)
  643. if matching_files:
  644. hf_folder = download_weights_from_hf(
  645. model_name_or_path,
  646. self.load_config.download_dir,
  647. [pattern],
  648. revision,
  649. ignore_patterns=self.load_config.ignore_patterns,
  650. )
  651. return glob.glob(os.path.join(hf_folder, pattern)), pattern
  652. raise RuntimeError(
  653. f"No model weights found in: `{model_name_or_path}`")
  654. def _prepare_weights(self, model_name_or_path: str,
  655. revision: Optional[str]) -> Tuple[List[str], bool]:
  656. """Prepare weight files for the model."""
  657. allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
  658. hf_weights_files, matched_pattern = self._get_weight_files(
  659. model_name_or_path, allowed_patterns, revision)
  660. if matched_pattern != "*.safetensors":
  661. hf_weights_files = filter_files_not_needed_for_inference(
  662. hf_weights_files)
  663. if len(hf_weights_files) == 0:
  664. raise RuntimeError(
  665. f"Cannot find any model weights with `{model_name_or_path}`")
  666. return hf_weights_files, matched_pattern == "*.safetensors"
  def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
      """Return a ``(name, tensor)`` iterator over ``hf_weights_files``.

      Uses the safetensors reader when ``use_safetensors`` is true, and the
      PyTorch checkpoint reader otherwise.
      """
      make_iterator = (safetensors_weights_iterator
                       if use_safetensors else pt_weights_iterator)
      return make_iterator(hf_weights_files)
  def _get_quantized_weights_iterator(
      self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
  ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                    Any]]:
      """Get an iterator to the model weights with bitsandbytes quantization,
      as well as the quantization state dictionary.

      Args:
          model_name_or_path: Local directory or Hugging Face repo id of the
              checkpoint.
          revision: Optional revision passed to weight-file discovery.
          pre_quant: If ``True``, the checkpoint is already nf4-quantized and
              its quant states are parsed from it. If ``False``, the
              ``.weight`` tensors of ``self.target_modules`` are quantized on
              the fly.

      Returns:
          ``(weights_generator, quant_state_dict)``. ``quant_state_dict`` is
          returned empty and is populated while the generator is consumed.
          Its keys are the renamed ``*.qweight`` parameter names.

      Raises:
          ImportError: If bitsandbytes is missing or older than 0.42.0.
          NotImplementedError: Raised while iterating a pre-quantized
              checkpoint that contains fp4 quant states.
      """
      import re

      def _version_tuple(version: str) -> Tuple[int, ...]:
          # Numeric (major, minor, patch). A plain string comparison gets
          # multi-digit components wrong ("0.9.0" > "0.42.0" as strings).
          numbers = []
          for piece in (version.split(".") + ["0", "0"])[:3]:
              match = re.match(r"\d+", piece)
              numbers.append(int(match.group()) if match else 0)
          return tuple(numbers)

      def _qweight_name(weight_name: str) -> str:
          # Rename only the trailing ".weight"; str.replace would also
          # rewrite a ".weight" occurring inside a module name.
          return weight_name[:-len(".weight")] + ".qweight"

      # only load the bitsandbytes module when needed
      try:
          import bitsandbytes
          from bitsandbytes.functional import QuantState
          if _version_tuple(bitsandbytes.__version__) < (0, 42, 0):
              raise ImportError("bitsandbytes version is wrong. Please "
                                "install bitsandbytes>=0.42.0.")
          from bitsandbytes.functional import quantize_4bit
      except ImportError as err:
          raise ImportError("Please install bitsandbytes>=0.42.0 via "
                            "`pip install bitsandbytes>=0.42.0` to use "
                            "bitsandbytes quantizer.") from err

      hf_weights_files, use_safetensors = self._prepare_weights(
          model_name_or_path, revision)

      quant_state_dict = {}

      def quantized_checkpoint() -> Generator:
          # First pass: keep every tensor whose name does not end in
          # ".weight" (quant-state components and other auxiliary tensors).
          temp_state_dict = {}
          for weight_name, weight_tensor in self._hf_weight_iter(
                  hf_weights_files, use_safetensors):
              if weight_name.endswith(".weight"):
                  continue
              # TODO: only nf4 quantization is supported for now
              if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
                  raise NotImplementedError(
                      "Only bitsandbytes_nf4 quantization "
                      f"is supported for now. {weight_name} is fp4 quantized"
                  )
              temp_state_dict[weight_name] = weight_tensor

          # Closure to parse quant_state for each prequant weight
          def _parse_quant_state(param_name: str,
                                 temp_state_dict: Dict) -> QuantState:
              # Match by prefix. A substring test could also collect the
              # components of another weight whose name merely contains
              # ``param_name``.
              prefix = param_name + "."
              quant_state = {
                  k: v
                  for k, v in temp_state_dict.items() if k.startswith(prefix)
              }
              # bitsandbytes library requires
              # weight.quant_state.bitsandbytes__nf4 in CPU
              nf4_key = param_name + ".quant_state.bitsandbytes__nf4"
              quant_state[nf4_key] = quant_state[nf4_key].cpu().data
              return QuantState.from_dict(quant_state, device="cuda")

          # Second pass: yield only ".weight" tensors. Those with a matching
          # nf4 quant state are pre-quantized and are renamed to ".qweight".
          # NOTE(review): non-".weight" tensors (e.g. biases) are never
          # yielded on this path; confirm no supported model needs them.
          for weight_name, weight_tensor in self._hf_weight_iter(
                  hf_weights_files, use_safetensors):
              if not weight_name.endswith(".weight"):
                  continue
              if (weight_name + ".quant_state.bitsandbytes__nf4"
                      in temp_state_dict):
                  qweight_name = _qweight_name(weight_name)
                  quant_state_dict[qweight_name] = _parse_quant_state(
                      weight_name, temp_state_dict)
                  yield qweight_name, weight_tensor
              else:
                  yield weight_name, weight_tensor

      def generator() -> Generator:
          for weight_name, weight_tensor in self._hf_weight_iter(
                  hf_weights_files, use_safetensors):
              # Quantize only the ".weight" tensors of target modules. Other
              # tensors of those modules (e.g. biases) pass through as-is.
              if weight_name.endswith(".weight") and any(
                      target_module in weight_name
                      for target_module in self.target_modules):
                  weight_name = _qweight_name(weight_name)
                  # bitsandbytes requires data in GPU
                  loaded_weight = weight_tensor.cuda().data
                  with set_default_torch_dtype(torch.float32):
                      processed_weight, quant_state = quantize_4bit(
                          loaded_weight,
                          compress_statistics=True,
                          quant_type="nf4")
                  quant_state_dict[weight_name] = quant_state
              else:
                  processed_weight = weight_tensor
              yield weight_name, processed_weight

      if pre_quant:
          return quantized_checkpoint(), quant_state_dict
      return generator(), quant_state_dict
  759. def _load_weights(self, model_config: ModelConfig,
  760. model: nn.Module) -> None:
  761. if not hasattr(model, 'load_weights'):
  762. raise AttributeError(
  763. "The required method 'load_weights' is not defined in class"
  764. f" {type(self).__name__}.")
  765. if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
  766. raise AttributeError(
  767. f"Model {type(self).__name__} does not support BitsAndBytes "
  768. "quantization yet.")
  769. logger.info("Loading weights with BitsAndBytes quantization. "
  770. "This May take a while ...")
  771. is_quantized_checkpoint = False
  772. quant_config = getattr(model_config.hf_config, "quantization_config",
  773. None)
  774. if quant_config is not None and quant_config.get(
  775. 'quant_method') == "bitsandbytes":
  776. is_quantized_checkpoint = True
  777. qweight_iterator, quant_state_dict = \
  778. self._get_quantized_weights_iterator(
  779. model_config.model, model_config.revision, is_quantized_checkpoint)
  780. model.load_weights(qweight_iterator)
  781. torch.cuda.empty_cache()
  782. param_dict = dict(model.named_parameters())
  783. stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
  784. for quant_param_name in quant_state_dict:
  785. non_stacked_param_name = quant_param_name
  786. shard_index = 0
  787. for shard_name, (
  788. weight_name, index
  789. ) in model.bitsandbytes_stacked_params_mapping.items():
  790. if shard_name in quant_param_name:
  791. shard_index = index
  792. quant_param_name = quant_param_name.replace(
  793. shard_name, weight_name)
  794. break
  795. if quant_param_name not in param_dict:
  796. raise ValueError(
  797. f"Parameter {quant_param_name} not found in the model.")
  798. if quant_param_name not in stacked_quant_state_dict:
  799. stacked_quant_state_dict[quant_param_name] = {}
  800. stacked_quant_state_dict[quant_param_name][shard_index] = (
  801. quant_state_dict[non_stacked_param_name])
  802. # save quant_states and offsets as the attributes of the parameters
  803. for param_name, param in param_dict.items():
  804. if param_name in stacked_quant_state_dict:
  805. quant_states = stacked_quant_state_dict[param_name]
  806. set_weight_attrs(param, {"bnb_quant_state": quant_states})
  807. pack_ratio = getattr(param, "pack_factor", -1)
  808. if pack_ratio == -1:
  809. raise ValueError(
  810. f"pack_factor not set for parameter {param_name}.")
  811. num_elements = [0] * len(quant_states)
  812. for seq, quant_state in quant_states.items():
  813. num_elements[seq] = math.prod(
  814. quant_state.shape) // pack_ratio
  815. offsets = np.concatenate(([0], np.cumsum(num_elements)))
  816. set_weight_attrs(param, {"bnb_shard_offsets": offsets})
  def load_model(self, *, model_config: ModelConfig,
                 device_config: DeviceConfig,
                 lora_config: Optional[LoRAConfig],
                 parallel_config: ParallelConfig,
                 scheduler_config: SchedulerConfig,
                 cache_config: CacheConfig) -> nn.Module:
      """Build the model and load bitsandbytes-quantized weights into it.

      Construction and weight loading both run with ``model_config.dtype``
      as the default torch dtype and ``device_config.device`` as the default
      device. ``parallel_config`` and ``scheduler_config`` are accepted but
      not used here.

      Returns:
          The loaded model, switched to eval mode.
      """
      with set_default_torch_dtype(model_config.dtype), \
              torch.device(device_config.device):
          model = _initialize_model(model_config, self.load_config,
                                    lora_config, cache_config)
          self._load_weights(model_config, model)
      return model.eval()
  class GGUFModelLoader(BaseModelLoader):
      """Model loader for checkpoints stored in the GGUF format.

      ``model_config.model`` must be the path of an existing local GGUF
      file. GGUF tensor names are translated to Hugging Face parameter names
      before the weights are handed to the model.
      NOTE(review): earlier docs claimed sharded-model support, but
      ``_prepare_weights`` accepts a single file only; confirm.
      """

      def __init__(self, load_config: LoadConfig):
          super().__init__(load_config)
          # This loader takes no extra options; reject any that are given.
          if not load_config.model_loader_extra_config:
              return
          raise ValueError(f"Model loader extra config is not supported for "
                           f"load format {load_config.load_format}")

      def _prepare_weights(self, model_name_or_path: str):
          """Return ``model_name_or_path`` if it is an existing file.

          Raises:
              ValueError: If the path is not a file.
          """
          if not os.path.isfile(model_name_or_path):
              raise ValueError(f"{model_name_or_path} is not a file.")
          return model_name_or_path

      def _get_gguf_weights_map(self, model_config: ModelConfig):
          """Build a mapping from GGUF tensor names to HF parameter names.

          GGUF names tensors converted from HF checkpoints as
          ``blk.N.BB.weight`` / ``blk.N.BB.bias``, where ``N`` is the layer
          (block) index and ``BB`` names the attention/MLP component. See
          "Standardized tensor names" in
          https://github.com/ggerganov/ggml/blob/master/docs/gguf.md.

          Raises:
              RuntimeError: If the model type has no GGUF architecture.
          """
          hf_config = model_config.hf_config
          # hack: the GGUF architecture name for "cohere" is "command-r".
          model_type = ("command-r" if hf_config.model_type == "cohere" else
                        hf_config.model_type)

          arch = next((key for key, value in gguf.MODEL_ARCH_NAMES.items()
                       if value == model_type), None)
          if arch is None:
              raise RuntimeError(f"Unknown gguf model_type: {model_type}")

          name_map = gguf.get_tensor_name_map(arch,
                                              hf_config.num_hidden_layers)
          # Build the HF model on the meta device: only the names in its
          # state dict are used below.
          with torch.device("meta"):
              dummy_model = AutoModelForCausalLM.from_config(hf_config)

          mapping = {}
          for hf_name in dummy_model.state_dict():
              module_path, suffix = hf_name.rsplit(".", 1)
              gguf_name = name_map.get_name(module_path)
              mapping[f"{gguf_name}.{suffix}"] = hf_name
          return mapping

      def _get_weights_iterator(
          self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
      ) -> Generator[Tuple[str, torch.Tensor], None, None]:
          """Iterate ``(hf_name, tensor)`` pairs from the GGUF file."""
          return gguf_quant_weights_iterator(model_name_or_path,
                                             gguf_to_hf_name_map)

      def load_model(self, *, model_config: ModelConfig,
                     device_config: DeviceConfig,
                     lora_config: Optional[LoRAConfig],
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig,
                     cache_config: CacheConfig) -> nn.Module:
          """Build the model and load its weights from the GGUF file."""
          local_model_path = self._prepare_weights(model_config.model)
          gguf_weights_map = self._get_gguf_weights_map(model_config)
          # Tied word embeddings can only be detected after mapping names.
          # Presumably get_gguf_extra_tensor_names lists mapped HF names
          # that have no tensor in the file (confirm). A missing
          # lm_head.weight means the head reuses the embeddings.
          extra_names = get_gguf_extra_tensor_names(local_model_path,
                                                    gguf_weights_map)
          if "lm_head.weight" in extra_names:
              model_config.hf_config.update({"tie_word_embeddings": True})

          with set_default_torch_dtype(model_config.dtype):
              with torch.device(device_config.device):
                  model = _initialize_model(model_config, self.load_config,
                                            lora_config, cache_config)
              model.load_weights(
                  self._get_weights_iterator(local_model_path,
                                             gguf_weights_map))
          return model
  def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
      """Get a model loader based on the load format.

      If ``load_config.load_format`` is itself a class, it is instantiated
      directly. Otherwise the format is compared with ``==`` against the
      known ``LoadFormat`` members, in order. Unknown formats fall back to
      ``DefaultModelLoader``.
      """
      load_format = load_config.load_format
      if isinstance(load_format, type):
          return load_format(load_config)

      # Ordered pairs scanned with ``==``. A dict lookup would use hashing,
      # which differs from ``==`` for str-mixin enums compared to plain
      # strings.
      known_loaders = (
          (LoadFormat.DUMMY, DummyModelLoader),
          (LoadFormat.TENSORIZER, TensorizerLoader),
          (LoadFormat.SHARDED_STATE, ShardedStateLoader),
          (LoadFormat.BITSANDBYTES, BitsAndBytesModelLoader),
          (LoadFormat.GGUF, GGUFModelLoader),
      )
      for known_format, loader_cls in known_loaders:
          if load_format == known_format:
              return loader_cls(load_config)
      return DefaultModelLoader(load_config)