# ruff: noqa: SIM117
import collections
import copy
import fnmatch
import glob
import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Type

import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from loguru import logger
from torch import nn

from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                     DeviceConfig, LoadConfig, LoadFormat,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.utils import is_tpu
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
    serialize_aphrodite_model, tensorizer_weights_iterator)
from aphrodite.modeling.model_loader.utils import (get_model_architecture,
                                                   set_default_torch_dtype)
from aphrodite.modeling.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
from aphrodite.modeling.models.interfaces import (has_inner_state,
                                                  supports_lora,
                                                  supports_vision)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import QuantizationConfig


def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Get the quantization config."""
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config, load_config)
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        return quant_config
    return None
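
# A minimal sketch of the capability check above (illustrative, not executed;
# assumes a device with compute capability (8, 0), i.e. an A100):
#
#     major, minor = (8, 0)
#     packed = major * 10 + minor              # -> 80
#     packed >= quant_config.get_min_capability()  # e.g. a method needing 75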


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module],
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
    """Get extra kwargs for model initialization."""
    extra_kwargs: Dict[str, Any] = {}
    if supports_lora(model_class):
        # lora_config=None is used to disable LoRA
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on GitHub.")
    if supports_vision(model_class):
        if multimodal_config is None:
            raise ValueError("Provide vision-related configurations "
                             "through the LLM entrypoint or engine "
                             "arguments.")
        extra_kwargs["multimodal_config"] = multimodal_config
    if has_inner_state(model_class) and scheduler_config:
        extra_kwargs["scheduler_config"] = scheduler_config
    return extra_kwargs


def _initialize_model(
        model_config: ModelConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        cache_config: CacheConfig,
        scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
    """Initialize a model with the given configurations."""
    model_class = get_model_architecture(model_config)[0]
    quant_config = _get_quantization_config(model_config, load_config)

    return model_class(config=model_config.hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **_get_model_initialization_kwargs(
                           model_class, lora_config, multimodal_config,
                           scheduler_config))


class BaseModelLoader(ABC):
    """Base class for model loaders."""

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model with the given configurations."""
        ...


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
        True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if APHRODITE_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(model_name_or_path,
                                                 self.load_config.download_dir,
                                                 allow_patterns, revision)
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir,
                    revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors
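
    # A worked illustration of the pattern fall-through above (comments only,
    # not executed): with load_format=AUTO and fall_back_to_pt=True,
    # allow_patterns is ["*.safetensors", "*.bin", "*.pt"]; the first pattern
    # that matches any file wins, so a repo containing both safetensors and
    # .bin shards is loaded from the safetensors files alone.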

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str],
        fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only supports *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                model_name_or_path, self.load_config.download_dir, hf_folder,
                hf_weights_files)
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)
        return weights_iterator
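
    # Shape of the stream produced above (illustrative, not executed): the
    # iterator yields (name, tensor) pairs such as
    # ("model.embed_tokens.weight", <torch.Tensor>), which
    # `model.load_weights` consumes one at a time without materializing the
    # whole checkpoint in memory.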

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config, scheduler_config)
            model.load_weights(
                self._get_weights_iterator(model_config.model,
                                           model_config.revision,
                                           fall_back_to_pt=getattr(
                                               model,
                                               "fall_back_to_pt_during_load",
                                               True)))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config, scheduler_config)
            # NOTE: For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't Aphrodite-tensorized (see
        examples/tensorize_aphrodite_model.py). This should still be faster
        than default HuggingFace loading, but will be slower than loading an
        Aphrodite-tensorized model.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects an Aphrodite-tensorized model. See the
        examples/tensorize_aphrodite_model.py example script
        for serializing Aphrodite models."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, multimodal_config)
                extra_kwargs["quant_config"] = quant_config
                extra_kwargs["cache_config"] = cache_config

                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
        return model.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from aphrodite.distributed import get_tensor_model_parallel_rank
            self.tensorizer_config.tensorizer_uri = \
                self.tensorizer_config.tensorizer_uri \
                    % get_tensor_model_parallel_rank()

        if is_aphrodite_tensorized(self.tensorizer_config):
            return self._load_model_serialized(model_config, device_config,
                                               lora_config, multimodal_config,
                                               cache_config)
        return self._load_model_serialized_cpu(model_config, device_config,
                                               lora_config, multimodal_config,
                                               cache_config)
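
    # Illustrative note on the `%` formatting above (not executed): with
    # tensor parallelism, `tensorizer_uri` is expected to be a printf-style
    # template, e.g. a hypothetical "s3://bucket/model-rank-%03d.tensors",
    # which each worker expands with its own tensor-parallel rank.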

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        serialize_aphrodite_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )


class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each
    worker only needs to read its own shard rather than the entire
    checkpoint. See `examples/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
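    # For example (illustrative only), rank 0's first shard is written as
    # DEFAULT_PATTERN.format(rank=0, part=0)
    # -> "model-rank-0-part-0.safetensors".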

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        same_storage_groups: Dict[Any, List[Tuple[
            str, torch.Tensor]]] = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result
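
    # Illustrative example of the filtering above (not executed): with tied
    # embeddings, `lm_head.weight` and `model.embed_tokens.weight` view the
    # same storage range; only the lexicographically smaller key survives, so
    # the shared data is serialized exactly once.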

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(model_name_or_path,
                                            self.load_config.download_dir,
                                            allow_patterns, revision)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        from safetensors.torch import safe_open

        from aphrodite.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning("loading tensor of shape "
                                           f"{tensor.shape} into parameter "
                                           f"'{key}' of shape {param_shape}")
                        param_data.copy_(tensor)
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from safetensors.torch import save_file

        from aphrodite.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
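
    # A minimal usage sketch (commented out; assumes a worker process where
    # tensor parallelism is already initialized):
    #
    #     ShardedStateLoader.save_model(model, "/ckpt/sharded",
    #                                   max_size=4 * 2**30)  # ~4 GiB parts
    #
    # Each rank then writes only its own shards, e.g.
    # "model-rank-0-part-0.safetensors", "model-rank-0-part-1.safetensors".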


class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitsAndBytes quantization."""

    default_target_modules = [
        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
        "o_proj"
    ]

    possible_config_file_names = ["adapter_config.json"]

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # We don't need to quantize the whole model, only the target modules
        # that are specified in the adapter config file. If the adapter config
        # file is not provided, we will quantize the default modules.
        if (not load_config.model_loader_extra_config
                or "qlora_adapter_name_or_path"
                not in load_config.model_loader_extra_config):
            self.target_modules = self.default_target_modules
            return

        qlora_adapter = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"]
        config_file_path = self._get_config_file(qlora_adapter)

        with open(config_file_path, "r") as f:
            config = json.load(f)
            self.target_modules = config["target_modules"]

    def _get_config_file(self, qlora_adapter: str) -> str:
        is_local = os.path.isdir(qlora_adapter)
        config_file_path = None
        if is_local:
            for file in self.possible_config_file_names:
                candidate = os.path.join(qlora_adapter, file)
                if os.path.exists(candidate):
                    config_file_path = candidate
                    break
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
            for file in self.possible_config_file_names:
                if file in repo_files:
                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
                                                       filename=file)
                    break

        if not config_file_path:
            raise ValueError(
                f"Cannot find adapter config file in {qlora_adapter}")

        return config_file_path

    def _get_weight_files(
            self,
            model_name_or_path: str,
            allowed_patterns: List[str],
            revision: Optional[str] = None) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        Return the weight files and the file pattern."""
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path, self.load_config.download_dir,
                        [pattern], revision)
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model."""

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

    def _get_quantized_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str]
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary."""

        # Only load the bitsandbytes module when needed.
        try:
            import bitsandbytes
            if bitsandbytes.__version__ < "0.42.0":
                raise ImportError("bitsandbytes version is too old. Please "
                                  "install bitsandbytes>=0.42.0.")
            from bitsandbytes.functional import quantize_4bit
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "the bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict = {}
        if use_safetensors:
            weight_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weight_iterator = pt_weights_iterator(hf_weights_files)

        def generator():
            for weight_name, weight_tensor in weight_iterator:
                if any(target_module in weight_name
                       for target_module in self.target_modules):
                    weight_name = weight_name.replace(".weight", ".qweight")
                    # bitsandbytes requires the weights to be on the GPU.
                    loaded_weight = weight_tensor.cuda().data
                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
                            loaded_weight,
                            compress_statistics=True,
                            quant_type="nf4")
                    quant_state_dict[weight_name] = quant_state
                else:
                    processed_weight = weight_tensor
                yield weight_name, processed_weight

        return generator(), quant_state_dict
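
    # Worked illustration of the generator above (comments only): a target
    # weight such as "model.layers.0.self_attn.q_proj.weight" is NF4-quantized
    # on the GPU and re-emitted as
    # "model.layers.0.self_attn.q_proj.qweight", while non-target weights
    # (e.g. layer norms) pass through unchanged.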

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        if not hasattr(model, 'load_weights'):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet.")

        logger.info("Loading weights with BitsAndBytes quantization. "
                    "This may take a while ...")

        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(model_config.model,
                                                 model_config.revision))

        model.load_weights(qweight_iterator)

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for quant_param_name in quant_state_dict:
            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                    weight_name, index
            ) in model.bitsandbytes_stacked_params_mapping.items():
                if shard_name in quant_param_name:
                    shard_index = index
                    quant_param_name = quant_param_name.replace(
                        shard_name, weight_name)
                    break

            if quant_param_name not in param_dict:
                raise ValueError(
                    f"Parameter {quant_param_name} not found in the model.")

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = (
                quant_state_dict[non_stacked_param_name])

        # Save quant_states and offsets as the attributes of the parameters.
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
                for seq, quant_state in enumerate(quant_states.items()):
                    num_elements[seq] = math.prod(
                        quant_state[1].shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})
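
    # Illustrative note (not executed): `bitsandbytes_stacked_params_mapping`
    # is model-defined; a typical entry such as
    # {"q_proj": ("qkv_proj", 0), "k_proj": ("qkv_proj", 1),
    #  "v_proj": ("qkv_proj", 2)} maps each quantized shard onto its slot in
    # the fused parameter, and `bnb_shard_offsets` records where each shard's
    # packed elements start within it.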

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config)

                self._load_weights(model_config, model)

        return model.eval()


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    if load_config.load_format == LoadFormat.BITSANDBYTES:
        return BitsAndBytesModelLoader(load_config)

    return DefaultModelLoader(load_config)
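

# A minimal usage sketch (commented out; the surrounding engine normally
# builds these config objects, and the variable names below are assumptions):
#
#     load_config = LoadConfig(load_format=LoadFormat.AUTO)
#     loader = get_model_loader(load_config)
#     model = loader.load_model(model_config=model_config,
#                               device_config=device_config,
#                               lora_config=None,
#                               multimodal_config=None,
#                               parallel_config=parallel_config,
#                               scheduler_config=scheduler_config,
#                               cache_config=cache_config)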