1
0

loader.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055
  1. # ruff: noqa: SIM117
  2. import collections
  3. import copy
  4. import fnmatch
  5. import glob
  6. import json
  7. import math
  8. import os
  9. from abc import ABC, abstractmethod
  10. from contextlib import contextmanager
  11. from typing import Any, Dict, Generator, List, Optional, Tuple, Type
  12. import gguf
  13. import huggingface_hub
  14. import numpy as np
  15. import torch
  16. from huggingface_hub import HfApi, hf_hub_download
  17. from loguru import logger
  18. from torch import nn
  19. from transformers import AutoModelForCausalLM, PretrainedConfig
  20. from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
  21. DeviceConfig, LoadConfig, LoadFormat,
  22. LoRAConfig, ModelConfig, MultiModalConfig,
  23. ParallelConfig, SchedulerConfig)
  24. from aphrodite.common.utils import is_pin_memory_available
  25. from aphrodite.modeling.model_loader.tensorizer import (
  26. TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
  27. serialize_aphrodite_model, tensorizer_weights_iterator)
  28. from aphrodite.modeling.model_loader.utils import (get_model_architecture,
  29. set_default_torch_dtype)
  30. from aphrodite.modeling.model_loader.weight_utils import (
  31. download_safetensors_index_file_from_hf, download_weights_from_hf,
  32. filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
  33. get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
  34. initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
  35. safetensors_weights_iterator)
  36. from aphrodite.modeling.models.interfaces import (has_inner_state,
  37. supports_lora,
  38. supports_multimodal)
  39. from aphrodite.modeling.utils import set_weight_attrs
  40. from aphrodite.platforms import current_platform
  41. from aphrodite.quantization.base_config import QuantizationConfig
  42. @contextmanager
  43. def device_loading_context(module: torch.nn.Module,
  44. target_device: torch.device):
  45. if target_device.type == "cpu":
  46. # If target is CPU, no need to move anything
  47. yield module
  48. return
  49. original_device_states: Dict[str, torch.device] = {}
  50. # Store original device states and move parameters to GPU if they're on CPU
  51. for name, p in module.named_parameters():
  52. if p.device.type == "cpu":
  53. original_device_states[name] = p.device
  54. p.data = p.data.to(target_device)
  55. # Parameters already on target device are not touched
  56. try:
  57. yield module
  58. finally:
  59. # Restore parameters to their original devices, ignoring new parameters
  60. pin_memory = is_pin_memory_available()
  61. for name, p in module.named_parameters():
  62. if name in original_device_states:
  63. original_device: torch.device = original_device_states[name]
  64. if original_device.type == "cpu":
  65. # `torch.empty_like` does not support `pin_memory` argument
  66. cpu_data = torch.empty_strided(size=p.data.size(),
  67. stride=p.data.stride(),
  68. dtype=p.data.dtype,
  69. layout=p.data.layout,
  70. device="cpu",
  71. pin_memory=pin_memory)
  72. cpu_data.copy_(p.data)
  73. p.data = cpu_data
  74. else:
  75. p.data = p.data.to(original_device)
  76. # New parameters or parameters already on target device are untouched
  77. def _get_quantization_config(
  78. model_config: ModelConfig,
  79. load_config: LoadConfig) -> Optional[QuantizationConfig]:
  80. """Get the quantization config."""
  81. if model_config.quantization is not None:
  82. quant_config = get_quant_config(model_config, load_config)
  83. if not current_platform.is_tpu():
  84. capability = current_platform.get_device_capability()
  85. capability = capability[0] * 10 + capability[1]
  86. if capability < quant_config.get_min_capability():
  87. raise ValueError(
  88. f"The quantization method {model_config.quantization} "
  89. "is not supported for the current GPU. "
  90. f"Minimum capability: {quant_config.get_min_capability()}. "
  91. f"Current capability: {capability}.")
  92. supported_dtypes = quant_config.get_supported_act_dtypes()
  93. if model_config.dtype not in supported_dtypes:
  94. raise ValueError(
  95. f"{model_config.dtype} is not supported for quantization "
  96. f"method {model_config.quantization}. Supported dtypes: "
  97. f"{supported_dtypes}")
  98. return quant_config
  99. return None
  100. def _get_model_initialization_kwargs(
  101. model_class: Type[nn.Module],
  102. lora_config: Optional[LoRAConfig],
  103. multimodal_config: Optional[MultiModalConfig],
  104. scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
  105. """Get extra kwargs for model initialization."""
  106. extra_kwargs: Dict[str, Any] = {}
  107. if supports_lora(model_class):
  108. # lora_config=None is used to disable LoRA
  109. extra_kwargs["lora_config"] = lora_config
  110. elif lora_config:
  111. raise ValueError(
  112. f"Model {model_class.__name__} does not support LoRA, "
  113. "but LoRA is enabled. Support for this model may "
  114. "be added in the future. If this is important to you, "
  115. "please open an issue on github.")
  116. if supports_multimodal(model_class):
  117. assert multimodal_config is not None
  118. extra_kwargs["multimodal_config"] = multimodal_config
  119. if has_inner_state(model_class) and scheduler_config:
  120. extra_kwargs["scheduler_config"] = scheduler_config
  121. return extra_kwargs
  122. def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
  123. cache_config: Optional[CacheConfig],
  124. quant_config: Optional[QuantizationConfig], *,
  125. lora_config: Optional[LoRAConfig],
  126. multimodal_config: Optional[MultiModalConfig],
  127. scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
  128. extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
  129. multimodal_config,
  130. scheduler_config)
  131. return model_class(config=hf_config,
  132. cache_config=cache_config,
  133. quant_config=quant_config,
  134. **extra_kwargs)
  135. def _initialize_model(
  136. model_config: ModelConfig,
  137. load_config: LoadConfig,
  138. lora_config: Optional[LoRAConfig],
  139. cache_config: CacheConfig,
  140. scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
  141. """Initialize a model with the given configurations."""
  142. model_class, _ = get_model_architecture(model_config)
  143. return build_model(
  144. model_class,
  145. model_config.hf_config,
  146. cache_config=cache_config,
  147. quant_config=_get_quantization_config(model_config, load_config),
  148. lora_config=lora_config,
  149. multimodal_config=model_config.multimodal_config,
  150. scheduler_config=scheduler_config,
  151. )
  152. class BaseModelLoader(ABC):
  153. """Base class for model loaders."""
  154. def __init__(self, load_config: LoadConfig):
  155. self.load_config = load_config
  156. @abstractmethod
  157. def load_model(self, *, model_config: ModelConfig,
  158. device_config: DeviceConfig,
  159. lora_config: Optional[LoRAConfig],
  160. parallel_config: ParallelConfig,
  161. scheduler_config: SchedulerConfig,
  162. cache_config: CacheConfig) -> nn.Module:
  163. """Load a model with the given configurations."""
  164. ...
  165. class DefaultModelLoader(BaseModelLoader):
  166. """Model loader that can load different file types from disk."""
  167. def __init__(self, load_config: LoadConfig):
  168. super().__init__(load_config)
  169. if load_config.model_loader_extra_config:
  170. raise ValueError(f"Model loader extra config is not supported for "
  171. f"load format {load_config.load_format}")
  172. def _maybe_download_from_modelscope(
  173. self, model: str, revision: Optional[str]) -> Optional[str]:
  174. """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
  175. True.
  176. Returns the path to the downloaded model, or None if the model is not
  177. downloaded from ModelScope."""
  178. if APHRODITE_USE_MODELSCOPE:
  179. # download model from ModelScope hub,
  180. # lazy import so that modelscope is not required for normal use.
  181. # pylint: disable=C.
  182. from modelscope.hub.snapshot_download import snapshot_download
  183. if not os.path.exists(model):
  184. model_path = snapshot_download(
  185. model_id=model,
  186. cache_dir=self.load_config.download_dir,
  187. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  188. revision=revision,
  189. ignore_file_pattern=self.load_config.ignore_patterns,
  190. )
  191. else:
  192. model_path = model
  193. return model_path
  194. return None
  195. def _prepare_weights(self, model_name_or_path: str,
  196. revision: Optional[str],
  197. fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
  198. """Prepare weights for the model.
  199. If the model is not local, it will be downloaded."""
  200. model_name_or_path = self._maybe_download_from_modelscope(
  201. model_name_or_path, revision) or model_name_or_path
  202. is_local = os.path.isdir(model_name_or_path)
  203. load_format = self.load_config.load_format
  204. use_safetensors = False
  205. # Some quantized models use .pt files for storing the weights.
  206. if load_format == LoadFormat.AUTO:
  207. allow_patterns = ["*.safetensors", "*.bin"]
  208. elif load_format == LoadFormat.SAFETENSORS:
  209. use_safetensors = True
  210. allow_patterns = ["*.safetensors"]
  211. elif load_format == LoadFormat.PT:
  212. allow_patterns = ["*.pt"]
  213. elif load_format == LoadFormat.NPCACHE:
  214. allow_patterns = ["*.bin"]
  215. else:
  216. raise ValueError(f"Unknown load_format: {load_format}")
  217. if fall_back_to_pt:
  218. allow_patterns += ["*.pt"]
  219. if not is_local:
  220. hf_folder = download_weights_from_hf(
  221. model_name_or_path,
  222. self.load_config.download_dir,
  223. allow_patterns,
  224. revision,
  225. ignore_patterns=self.load_config.ignore_patterns,
  226. )
  227. else:
  228. hf_folder = model_name_or_path
  229. hf_weights_files: List[str] = []
  230. for pattern in allow_patterns:
  231. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  232. if len(hf_weights_files) > 0:
  233. if pattern == "*.safetensors":
  234. use_safetensors = True
  235. break
  236. if use_safetensors:
  237. # For models like Mistral-7B-Instruct-v0.3
  238. # there are both sharded safetensors files and a consolidated
  239. # safetensors file. Using both breaks.
  240. # Here, we download the `model.safetensors.index.json` and filter
  241. # any files not found in the index.
  242. if not is_local:
  243. download_safetensors_index_file_from_hf(
  244. model_name_or_path, self.load_config.download_dir,
  245. revision)
  246. hf_weights_files = filter_duplicate_safetensors_files(
  247. hf_weights_files, hf_folder)
  248. else:
  249. hf_weights_files = filter_files_not_needed_for_inference(
  250. hf_weights_files)
  251. if len(hf_weights_files) == 0:
  252. raise RuntimeError(
  253. f"Cannot find any model weights with `{model_name_or_path}`")
  254. return hf_folder, hf_weights_files, use_safetensors
  255. def _get_weights_iterator(
  256. self, model_name_or_path: str, revision: Optional[str],
  257. fall_back_to_pt: bool
  258. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  259. """Get an iterator for the model weights based on the load format."""
  260. hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
  261. model_name_or_path, revision, fall_back_to_pt)
  262. if self.load_config.load_format == LoadFormat.NPCACHE:
  263. # Currently np_cache only support *.bin checkpoints
  264. assert use_safetensors is False
  265. weights_iterator = np_cache_weights_iterator(
  266. model_name_or_path, self.load_config.download_dir, hf_folder,
  267. hf_weights_files)
  268. elif use_safetensors:
  269. weights_iterator = safetensors_weights_iterator(hf_weights_files)
  270. else:
  271. weights_iterator = pt_weights_iterator(hf_weights_files)
  272. if current_platform.is_tpu():
  273. # In PyTorch XLA, we should call `xm.mark_step` frequently so that
  274. # not too many ops are accumulated in the XLA program.
  275. import torch_xla.core.xla_model as xm
  276. def _xla_weights_iterator(iterator: Generator):
  277. for weights in iterator:
  278. yield weights
  279. xm.mark_step()
  280. weights_iterator = _xla_weights_iterator(weights_iterator)
  281. return weights_iterator
  282. def load_model(self, *, model_config: ModelConfig,
  283. device_config: DeviceConfig,
  284. lora_config: Optional[LoRAConfig],
  285. parallel_config: ParallelConfig,
  286. scheduler_config: SchedulerConfig,
  287. cache_config: CacheConfig) -> nn.Module:
  288. target_device = torch.device(device_config.device)
  289. with set_default_torch_dtype(model_config.dtype):
  290. with target_device:
  291. model = _initialize_model(model_config, self.load_config,
  292. lora_config, cache_config,
  293. scheduler_config)
  294. model.load_weights(
  295. self._get_weights_iterator(model_config.model,
  296. model_config.revision,
  297. fall_back_to_pt=getattr(
  298. model,
  299. "fall_back_to_pt_during_load",
  300. True)), )
  301. for _, module in model.named_modules():
  302. quant_method = getattr(module, "quant_method", None)
  303. if quant_method is not None:
  304. # When quant methods need to process weights after loading
  305. # (for repacking, quantizing, etc), they expect parameters
  306. # to be on the global target device. This scope is for the
  307. # case where cpu offloading is used, where we will move the
  308. # parameters onto device for processing and back off after.
  309. with device_loading_context(module, target_device):
  310. quant_method.process_weights_after_loading(module)
  311. return model.eval()
  312. class DummyModelLoader(BaseModelLoader):
  313. """Model loader that will set model weights to random values."""
  314. def __init__(self, load_config: LoadConfig):
  315. super().__init__(load_config)
  316. if load_config.model_loader_extra_config:
  317. raise ValueError(f"Model loader extra config is not supported for "
  318. f"load format {load_config.load_format}")
  319. def load_model(self, *, model_config: ModelConfig,
  320. device_config: DeviceConfig,
  321. lora_config: Optional[LoRAConfig],
  322. parallel_config: ParallelConfig,
  323. scheduler_config: SchedulerConfig,
  324. cache_config: CacheConfig) -> nn.Module:
  325. with set_default_torch_dtype(model_config.dtype):
  326. with torch.device(device_config.device):
  327. model = _initialize_model(model_config, self.load_config,
  328. lora_config, cache_config,
  329. scheduler_config)
  330. # NOTE: For accurate performance evaluation, we assign
  331. # random values to the weights.
  332. initialize_dummy_weights(model)
  333. return model.eval()
  334. class TensorizerLoader(BaseModelLoader):
  335. """Model loader using CoreWeave's tensorizer library."""
  336. def __init__(self, load_config: LoadConfig):
  337. super().__init__(load_config)
  338. if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
  339. self.tensorizer_config = load_config.model_loader_extra_config
  340. else:
  341. self.tensorizer_config = TensorizerConfig(
  342. **load_config.model_loader_extra_config)
  343. def _verify_config(self, model_config: ModelConfig,
  344. parallel_config: ParallelConfig):
  345. self.tensorizer_config.verify_with_model_config(model_config)
  346. self.tensorizer_config.verify_with_parallel_config(parallel_config)
  347. def _get_weights_iterator(
  348. self) -> Generator[Tuple[str, torch.Tensor], None, None]:
  349. tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
  350. return tensorizer_weights_iterator(tensorizer_args)
  351. def _load_model_serialized_cpu(
  352. self,
  353. model_config: ModelConfig,
  354. device_config: DeviceConfig,
  355. lora_config: Optional[LoRAConfig],
  356. cache_config: CacheConfig,
  357. ) -> nn.Module:
  358. """Load a serialized model with tensorizer to the CPU.
  359. This is only necessary when the model isn't Aphrodite-tensorized (see
  360. examples/tensorize_aphrodite_model.py) This should still be faster than
  361. default HuggingFace loading, but will be slower than loading a
  362. Aphrodite-tensorized model.
  363. """
  364. with set_default_torch_dtype(model_config.dtype):
  365. with torch.device(device_config.device):
  366. model = _initialize_model(model_config, self.load_config,
  367. lora_config, cache_config)
  368. model.load_weights(self._get_weights_iterator())
  369. return model.eval()
  370. def _load_model_serialized(
  371. self,
  372. model_config: ModelConfig,
  373. device_config: DeviceConfig,
  374. lora_config: Optional[LoRAConfig],
  375. cache_config: CacheConfig,
  376. ) -> nn.Module:
  377. """Load a serialized model with tensorizer.
  378. Expects a Aphrodite-tensorized model. See the
  379. examples/tensorize_aphrodite_model.py example script
  380. for serializing Aphrodite models."""
  381. with set_default_torch_dtype(model_config.dtype):
  382. with torch.device(device_config.device):
  383. model_class = get_model_architecture(model_config)[0]
  384. quant_config = _get_quantization_config(
  385. model_config, self.load_config)
  386. extra_kwargs = _get_model_initialization_kwargs(
  387. model_class, lora_config, model_config.multimodal_config)
  388. extra_kwargs["quant_config"] = quant_config
  389. extra_kwargs["cache_config"] = cache_config
  390. tensorizer_config = copy.copy(self.tensorizer_config)
  391. tensorizer_config.model_class = model_class
  392. tensorizer_config.hf_config = model_config.hf_config
  393. tensorizer_config.dtype = model_config.dtype
  394. model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
  395. return model.eval()
  396. def load_model(self, *, model_config: ModelConfig,
  397. device_config: DeviceConfig,
  398. lora_config: Optional[LoRAConfig],
  399. parallel_config: ParallelConfig,
  400. scheduler_config: SchedulerConfig,
  401. cache_config: CacheConfig) -> nn.Module:
  402. self._verify_config(model_config, parallel_config)
  403. if parallel_config.tensor_parallel_size > 1:
  404. from aphrodite.distributed import get_tensor_model_parallel_rank
  405. self.tensorizer_config.tensorizer_uri = \
  406. self.tensorizer_config.tensorizer_uri \
  407. % get_tensor_model_parallel_rank()
  408. if is_aphrodite_tensorized(self.tensorizer_config):
  409. return self._load_model_serialized(model_config, device_config,
  410. lora_config, cache_config)
  411. return self._load_model_serialized_cpu(model_config, device_config,
  412. lora_config, cache_config)
  413. @staticmethod
  414. def save_model(
  415. model: torch.nn.Module,
  416. tensorizer_config: TensorizerConfig,
  417. ) -> None:
  418. serialize_aphrodite_model(
  419. model=model,
  420. tensorizer_config=tensorizer_config,
  421. )
  422. class ShardedStateLoader(BaseModelLoader):
  423. """
  424. Model loader that directly loads each worker's model state dict, which
  425. enables a fast load path for large tensor-parallel models where each worker
  426. only needs to read its own shard rather than the entire checkpoint. See
  427. `examples/save_sharded_state.py` for creating a sharded checkpoint.
  428. """
  429. DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
  430. def __init__(self, load_config: LoadConfig):
  431. super().__init__(load_config)
  432. extra_config = ({} if load_config.model_loader_extra_config is None
  433. else load_config.model_loader_extra_config.copy())
  434. self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
  435. if extra_config:
  436. raise ValueError(f"Unexpected extra config keys for load format "
  437. f"{load_config.load_format}: "
  438. f"{load_config.model_loader_extra_config.keys()}")
  439. @staticmethod
  440. def _filter_subtensors(
  441. tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
  442. """
  443. Filter out all tensors that share the same memory or a subset of the
  444. memory of another tensor.
  445. """
  446. same_storage_groups: Dict[Any, List[Tuple[
  447. str, torch.Tensor]]] = collections.defaultdict(list)
  448. for key, tensor in tensors.items():
  449. if tensor.numel():
  450. ptr = tensor.untyped_storage().data_ptr()
  451. same_storage_groups[tensor.device, ptr].append((key, tensor))
  452. def get_end_ptr(tensor: torch.Tensor) -> int:
  453. return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
  454. result: Dict[str, torch.Tensor] = {}
  455. for group in same_storage_groups.values():
  456. for k, t in group:
  457. a, b = t.data_ptr(), get_end_ptr(t)
  458. for k2, t2 in group:
  459. if not t2.is_contiguous():
  460. continue
  461. a2, b2 = t2.data_ptr(), get_end_ptr(t2)
  462. if a < a2 or b2 < b:
  463. continue
  464. if a2 < a or b < b2 or not t.is_contiguous():
  465. break # t2 covers strictly more memory than t.
  466. if k2 < k:
  467. # Same tensors, keep the one with the smaller key.
  468. break
  469. else:
  470. result[k] = t
  471. return result
  472. def _prepare_weights(self, model_name_or_path: str,
  473. revision: Optional[str]):
  474. if os.path.isdir(model_name_or_path):
  475. return model_name_or_path
  476. else:
  477. allow_patterns = ["*.safetensors"]
  478. return download_weights_from_hf(
  479. model_name_or_path,
  480. self.load_config.download_dir,
  481. allow_patterns,
  482. revision,
  483. ignore_patterns=self.load_config.ignore_patterns,
  484. )
  485. def load_model(self, *, model_config: ModelConfig,
  486. device_config: DeviceConfig,
  487. lora_config: Optional[LoRAConfig],
  488. parallel_config: ParallelConfig,
  489. scheduler_config: SchedulerConfig,
  490. cache_config: CacheConfig) -> nn.Module:
  491. from safetensors.torch import safe_open
  492. from aphrodite.distributed import get_tensor_model_parallel_rank
  493. local_model_path = self._prepare_weights(model_config.model,
  494. model_config.revision)
  495. with set_default_torch_dtype(model_config.dtype):
  496. with torch.device(device_config.device):
  497. model = _initialize_model(model_config, self.load_config,
  498. lora_config, cache_config)
  499. rank = get_tensor_model_parallel_rank()
  500. pattern = os.path.join(
  501. local_model_path,
  502. self.pattern.format(rank=rank, part="*"),
  503. )
  504. filepaths = glob.glob(pattern)
  505. if not filepaths:
  506. # TODO: support un-sharded checkpoints too
  507. raise ValueError(
  508. f"Could not find checkpoint files '{pattern}', only "
  509. f"pre-sharded checkpoints are currently supported!")
  510. state_dict = self._filter_subtensors(model.state_dict())
  511. for path in filepaths:
  512. with safe_open(path, framework="pt") as f:
  513. for key in f.keys(): # noqa: SIM118
  514. tensor = f.get_tensor(key)
  515. # If loading with LoRA enabled, additional padding may
  516. # be added to certain parameters. We only load into a
  517. # narrowed view of the parameter data.
  518. param_data = state_dict[key].data
  519. param_shape = state_dict[key].shape
  520. for dim, size in enumerate(tensor.shape):
  521. if size < param_shape[dim]:
  522. param_data = param_data.narrow(dim, 0, size)
  523. if tensor.shape != param_shape:
  524. logger.warning("loading tensor of shape "
  525. f"{tensor.shape} into parameter "
  526. f"'{key}' of shape {param_shape}")
  527. param_data.copy_(tensor)
  528. state_dict.pop(key)
  529. if state_dict:
  530. raise ValueError(
  531. f"Missing keys {tuple(state_dict)} in loaded state!")
  532. return model.eval()
  533. @staticmethod
  534. def save_model(
  535. model: torch.nn.Module,
  536. path: str,
  537. pattern: Optional[str] = None,
  538. max_size: Optional[int] = None,
  539. ) -> None:
  540. from safetensors.torch import save_file
  541. from aphrodite.distributed import get_tensor_model_parallel_rank
  542. if pattern is None:
  543. pattern = ShardedStateLoader.DEFAULT_PATTERN
  544. rank = get_tensor_model_parallel_rank()
  545. part_idx = 0
  546. total_size = 0
  547. state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
  548. state_dict_part: Dict[str, torch.Tensor] = {}
  549. for key, tensor in state_dict.items():
  550. param_size = tensor.nelement() * tensor.element_size()
  551. if max_size is not None and total_size + param_size > max_size:
  552. filename = pattern.format(rank=rank, part=part_idx)
  553. save_file(
  554. state_dict_part,
  555. os.path.join(path, filename),
  556. )
  557. part_idx += 1
  558. total_size = 0
  559. state_dict_part = {}
  560. state_dict_part[key] = tensor
  561. total_size += param_size
  562. if len(state_dict_part) > 0:
  563. filename = pattern.format(rank=rank, part=part_idx)
  564. save_file(
  565. state_dict_part,
  566. os.path.join(path, filename),
  567. )
  568. class BitsAndBytesModelLoader(BaseModelLoader):
  569. """Model loader to load model weights with BitAndBytes quantization."""
  570. default_target_modules = [
  571. "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
  572. "o_proj"
  573. ]
  574. possible_config_file_names = ["adapter_config.json"]
  575. def __init__(self, load_config: LoadConfig):
  576. super().__init__(load_config)
  577. # we don't need to quantize the whole model, only the target modules
  578. # that are specified in the adapter config file. If the adapter config
  579. # file is not provided, we will quantize the default modules.
  580. if (not load_config.model_loader_extra_config
  581. or "qlora_adapter_name_or_path"
  582. not in load_config.model_loader_extra_config):
  583. self.target_modules = self.default_target_modules
  584. return
  585. qlora_adapter = load_config.model_loader_extra_config[
  586. "qlora_adapter_name_or_path"]
  587. config_file_path = self._get_config_file(qlora_adapter)
  588. with open(config_file_path, "r") as f:
  589. config = json.load(f)
  590. self.target_modules = config["target_modules"]
  591. def _get_config_file(self, qlora_adapter: str) -> str:
  592. is_local = os.path.isdir(qlora_adapter)
  593. config_file_path = None
  594. if is_local:
  595. for file in self.possible_config_file_names:
  596. config_file_path = os.path.join(qlora_adapter, file)
  597. if os.path.exists(config_file_path):
  598. break
  599. else:
  600. hf_api = HfApi()
  601. repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
  602. for file in self.possible_config_file_names:
  603. if file in repo_files:
  604. config_file_path = hf_hub_download(repo_id=qlora_adapter,
  605. filename=file)
  606. break
  607. if not config_file_path:
  608. raise ValueError(
  609. f"Cannot find adapter config file in {qlora_adapter}")
  610. return config_file_path
  611. def _get_weight_files(
  612. self,
  613. model_name_or_path: str,
  614. allowed_patterns: List[str],
  615. revision: Optional[str] = None) -> Tuple[List[str], str]:
  616. """Retrieve weight files. Download the files if necessary.
  617. Return the weight files and the file pattern."""
  618. is_local = os.path.isdir(model_name_or_path)
  619. if is_local:
  620. for pattern in allowed_patterns:
  621. weight_files = glob.glob(
  622. os.path.join(model_name_or_path, pattern))
  623. if weight_files:
  624. return weight_files, pattern
  625. else:
  626. hf_api = HfApi()
  627. repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
  628. for pattern in allowed_patterns:
  629. matching_files = fnmatch.filter(repo_files, pattern)
  630. if matching_files:
  631. hf_folder = download_weights_from_hf(
  632. model_name_or_path,
  633. self.load_config.download_dir,
  634. [pattern],
  635. revision,
  636. ignore_patterns=self.load_config.ignore_patterns,
  637. )
  638. return glob.glob(os.path.join(hf_folder, pattern)), pattern
  639. raise RuntimeError(
  640. f"No model weights found in: `{model_name_or_path}`")
  641. def _prepare_weights(self, model_name_or_path: str,
  642. revision: Optional[str]) -> Tuple[List[str], bool]:
  643. """Prepare weight files for the model."""
  644. allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
  645. hf_weights_files, matched_pattern = self._get_weight_files(
  646. model_name_or_path, allowed_patterns, revision)
  647. if matched_pattern != "*.safetensors":
  648. hf_weights_files = filter_files_not_needed_for_inference(
  649. hf_weights_files)
  650. if len(hf_weights_files) == 0:
  651. raise RuntimeError(
  652. f"Cannot find any model weights with `{model_name_or_path}`")
  653. return hf_weights_files, matched_pattern == "*.safetensors"
  654. def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
  655. if use_safetensors:
  656. return safetensors_weights_iterator(hf_weights_files)
  657. else:
  658. return pt_weights_iterator(hf_weights_files)
  659. def _get_quantized_weights_iterator(
  660. self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
  661. ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
  662. Any]]:
  663. """Get an iterator to the model weights with bitsandbytes quantization,
  664. as well as the quantization state dictionary."""
  665. # only load the bitsandbytes module when needed
  666. try:
  667. import bitsandbytes
  668. from bitsandbytes.functional import QuantState
  669. if bitsandbytes.__version__ < "0.42.0":
  670. raise ImportError("bitsandbytes version is wrong. Please "
  671. "install bitsandbytes>=0.42.0.")
  672. from bitsandbytes.functional import quantize_4bit
  673. except ImportError as err:
  674. raise ImportError("Please install bitsandbytes>=0.42.0 via "
  675. "`pip install bitsandbytes>=0.42.0` to use "
  676. "bitsandbytes quantizer.") from err
  677. hf_weights_files, use_safetensors = self._prepare_weights(
  678. model_name_or_path, revision)
  679. quant_state_dict = {}
  680. def quantized_checkpoint() -> Generator:
  681. # First iterate over all quant state weights
  682. weight_iterator = self._hf_weight_iter(hf_weights_files,
  683. use_safetensors)
  684. temp_state_dict = {}
  685. for weight_name, weight_tensor in weight_iterator:
  686. if weight_name.endswith(".weight"):
  687. continue
  688. # TODO: only nf4 quantization is supported for now
  689. if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
  690. raise NotImplementedError(
  691. "Only bitsandbytes_nf4 quantization"
  692. f"is supported for now. {weight_name} is fp4 quantized"
  693. )
  694. temp_state_dict[weight_name] = weight_tensor
  695. # Closure to parse quant_state for each prequant weight
  696. def _parse_quant_state(param_name: str,
  697. temp_state_dict: Dict) -> QuantState:
  698. quant_state = {}
  699. for k in temp_state_dict:
  700. if param_name + "." in k:
  701. quant_state[k] = temp_state_dict[k]
  702. # bitsandbytes library requires
  703. # weight.quant_state.bitsandbytes__nf4 in CPU
  704. quant_state[param_name +
  705. ".quant_state.bitsandbytes__nf4"] = quant_state[
  706. param_name +
  707. ".quant_state.bitsandbytes__nf4"].cpu().data
  708. return QuantState.from_dict(quant_state, device="cuda")
  709. # Second iterate over all prequant and normal weights
  710. # pre quantized weights would have a quant_state
  711. for weight_name, weight_tensor in self._hf_weight_iter(
  712. hf_weights_files, use_safetensors):
  713. # Filter out all weights whose suffix is not ".weight"
  714. if not weight_name.endswith(".weight"):
  715. continue
  716. if weight_name + ".quant_state.bitsandbytes__nf4" \
  717. in temp_state_dict:
  718. quant_state = _parse_quant_state(weight_name,
  719. temp_state_dict)
  720. weight_name = weight_name.replace(".weight", ".qweight")
  721. quant_state_dict[weight_name] = quant_state
  722. yield weight_name.replace(".weight",
  723. ".qweight"), weight_tensor
  724. else:
  725. yield weight_name, weight_tensor
  726. def generator() -> Generator:
  727. for weight_name, weight_tensor in self._hf_weight_iter(
  728. hf_weights_files, use_safetensors):
  729. if any(target_module in weight_name
  730. for target_module in self.target_modules):
  731. weight_name = weight_name.replace(".weight", ".qweight")
  732. # bitsandbytes requires data in GPU
  733. loaded_weight = weight_tensor.cuda().data
  734. with set_default_torch_dtype(torch.float32):
  735. processed_weight, quant_state = quantize_4bit(
  736. loaded_weight,
  737. compress_statistics=True,
  738. quant_type="nf4")
  739. quant_state_dict[weight_name] = quant_state
  740. else:
  741. processed_weight = weight_tensor
  742. yield weight_name, processed_weight
  743. if pre_quant:
  744. return quantized_checkpoint(), quant_state_dict
  745. return generator(), quant_state_dict
  746. def _load_weights(self, model_config: ModelConfig,
  747. model: nn.Module) -> None:
  748. if not hasattr(model, 'load_weights'):
  749. raise AttributeError(
  750. "The required method 'load_weights' is not defined in class"
  751. f" {type(self).__name__}.")
  752. if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
  753. raise AttributeError(
  754. f"Model {type(self).__name__} does not support BitsAndBytes "
  755. "quantization yet.")
  756. logger.info("Loading weights with BitsAndBytes quantization. "
  757. "This May take a while ...")
  758. is_quantized_checkpoint = False
  759. quant_config = getattr(model_config.hf_config, "quantization_config",
  760. None)
  761. if quant_config is not None and quant_config.get(
  762. 'quant_method') == "bitsandbytes":
  763. is_quantized_checkpoint = True
  764. qweight_iterator, quant_state_dict = \
  765. self._get_quantized_weights_iterator(
  766. model_config.model, model_config.revision, is_quantized_checkpoint)
  767. model.load_weights(qweight_iterator)
  768. torch.cuda.empty_cache()
  769. param_dict = dict(model.named_parameters())
  770. stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
  771. for quant_param_name in quant_state_dict:
  772. non_stacked_param_name = quant_param_name
  773. shard_index = 0
  774. for shard_name, (
  775. weight_name, index
  776. ) in model.bitsandbytes_stacked_params_mapping.items():
  777. if shard_name in quant_param_name:
  778. shard_index = index
  779. quant_param_name = quant_param_name.replace(
  780. shard_name, weight_name)
  781. break
  782. if quant_param_name not in param_dict:
  783. raise ValueError(
  784. f"Parameter {quant_param_name} not found in the model.")
  785. if quant_param_name not in stacked_quant_state_dict:
  786. stacked_quant_state_dict[quant_param_name] = {}
  787. stacked_quant_state_dict[quant_param_name][shard_index] = (
  788. quant_state_dict[non_stacked_param_name])
  789. # save quant_states and offsets as the attributes of the parameters
  790. for param_name, param in param_dict.items():
  791. if param_name in stacked_quant_state_dict:
  792. quant_states = stacked_quant_state_dict[param_name]
  793. set_weight_attrs(param, {"bnb_quant_state": quant_states})
  794. pack_ratio = getattr(param, "pack_factor", -1)
  795. if pack_ratio == -1:
  796. raise ValueError(
  797. f"pack_factor not set for parameter {param_name}.")
  798. num_elements = [0] * len(quant_states)
  799. for seq, quant_state in quant_states.items():
  800. num_elements[seq] = math.prod(
  801. quant_state.shape) // pack_ratio
  802. offsets = np.concatenate(([0], np.cumsum(num_elements)))
  803. set_weight_attrs(param, {"bnb_shard_offsets": offsets})
  804. def load_model(self, *, model_config: ModelConfig,
  805. device_config: DeviceConfig,
  806. lora_config: Optional[LoRAConfig],
  807. parallel_config: ParallelConfig,
  808. scheduler_config: SchedulerConfig,
  809. cache_config: CacheConfig) -> nn.Module:
  810. with set_default_torch_dtype(model_config.dtype):
  811. with torch.device(device_config.device):
  812. model = _initialize_model(model_config, self.load_config,
  813. lora_config, cache_config)
  814. self._load_weights(model_config, model)
  815. return model.eval()
  816. class GGUFModelLoader(BaseModelLoader):
  817. """
  818. Model loader that can load GGUF files. This is useful for loading models
  819. that are quantized with GGUF and saved in the GGUF format. This loader
  820. supports loading both full models and sharded models.
  821. """
  822. def __init__(self, load_config: LoadConfig):
  823. super().__init__(load_config)
  824. if load_config.model_loader_extra_config:
  825. raise ValueError(f"Model loader extra config is not supported for "
  826. f"load format {load_config.load_format}")
  827. def _prepare_weights(self, model_name_or_path: str):
  828. if os.path.isfile(model_name_or_path):
  829. return model_name_or_path
  830. else:
  831. raise ValueError(f"{model_name_or_path} is not a file.")
  832. def _get_gguf_weights_map(self, model_config: ModelConfig):
  833. """
  834. GGUF uses this naming convention for their tensors from HF checkpoint:
  835. `blk.N.BB.weight` and `blk.N.BB.bias`
  836. where N signifies the block number of a layer, and BB signifies the
  837. attention/mlp layer components.
  838. See "Standardized tensor names" in
  839. https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
  840. """
  841. config = model_config.hf_config
  842. model_type = config.model_type
  843. # hack: ggufs have a different name than transformers
  844. if model_type == "cohere":
  845. model_type = "command-r"
  846. arch = None
  847. for key, value in gguf.MODEL_ARCH_NAMES.items():
  848. if value == model_type:
  849. arch = key
  850. break
  851. if arch is None:
  852. raise RuntimeError(f"Unknown gguf model_type: {model_type}")
  853. num_layers = config.num_hidden_layers
  854. name_map = gguf.get_tensor_name_map(arch, num_layers)
  855. with torch.device("meta"):
  856. dummy_model = AutoModelForCausalLM.from_config(config)
  857. state_dict = dummy_model.state_dict()
  858. gguf_to_hf_name_map = {}
  859. for hf_name in state_dict:
  860. name, suffix = hf_name.rsplit(".", 1)
  861. gguf_name = name_map.get_name(name)
  862. gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
  863. return gguf_to_hf_name_map
  864. def _get_weights_iterator(
  865. self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
  866. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  867. return gguf_quant_weights_iterator(model_name_or_path,
  868. gguf_to_hf_name_map)
  869. def load_model(self, *, model_config: ModelConfig,
  870. device_config: DeviceConfig,
  871. lora_config: Optional[LoRAConfig],
  872. parallel_config: ParallelConfig,
  873. scheduler_config: SchedulerConfig,
  874. cache_config: CacheConfig) -> nn.Module:
  875. local_model_path = self._prepare_weights(model_config.model)
  876. gguf_weights_map = self._get_gguf_weights_map(model_config)
  877. # we can only know if tie word embeddings after mapping weights
  878. if "lm_head.weight" in get_gguf_extra_tensor_names(
  879. local_model_path, gguf_weights_map):
  880. model_config.hf_config.update({"tie_word_embeddings": True})
  881. with set_default_torch_dtype(model_config.dtype):
  882. with torch.device(device_config.device):
  883. model = _initialize_model(model_config, self.load_config,
  884. lora_config, cache_config)
  885. model.load_weights(
  886. self._get_weights_iterator(local_model_path, gguf_weights_map))
  887. return model
  888. def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  889. """Get a model loader based on the load format."""
  890. if isinstance(load_config.load_format, type):
  891. return load_config.load_format(load_config)
  892. if load_config.load_format == LoadFormat.DUMMY:
  893. return DummyModelLoader(load_config)
  894. if load_config.load_format == LoadFormat.TENSORIZER:
  895. return TensorizerLoader(load_config)
  896. if load_config.load_format == LoadFormat.SHARDED_STATE:
  897. return ShardedStateLoader(load_config)
  898. if load_config.load_format == LoadFormat.BITSANDBYTES:
  899. return BitsAndBytesModelLoader(load_config)
  900. if load_config.load_format == LoadFormat.GGUF:
  901. return GGUFModelLoader(load_config)
  902. return DefaultModelLoader(load_config)