# loader.py

# ruff: noqa: SIM117
import collections
import copy
import dataclasses
import fnmatch
import glob
import json
import math
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
                    Type, cast)

import gguf
import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from loguru import logger
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                     DeviceConfig, LoadConfig, LoadFormat,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
    serialize_aphrodite_model, tensorizer_weights_iterator)
from aphrodite.modeling.model_loader.utils import (get_model_architecture,
                                                   set_default_torch_dtype)
from aphrodite.modeling.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
    safetensors_weights_iterator)
from aphrodite.modeling.models.interfaces import (has_inner_state,
                                                  supports_lora,
                                                  supports_multimodal)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import QuantizationConfig


@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    original_device_states: Dict[str, torch.device] = {}

    # Store original device states and move parameters to GPU if they're on CPU
    for name, p in module.named_parameters():
        if p.device.type == "cpu":
            original_device_states[name] = p.device
            p.data = p.data.to(target_device)
        # Parameters already on target device are not touched

    try:
        yield module
    finally:
        # Restore parameters to their original devices, ignoring new parameters
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name in original_device_states:
                original_device: torch.device = original_device_states[name]
                if original_device.type == "cpu":
                    # `torch.empty_like` does not support `pin_memory` argument
                    cpu_data = torch.empty_strided(size=p.data.size(),
                                                   stride=p.data.stride(),
                                                   dtype=p.data.dtype,
                                                   layout=p.data.layout,
                                                   device="cpu",
                                                   pin_memory=pin_memory)
                    cpu_data.copy_(p.data)
                    p.data = cpu_data
                else:
                    p.data = p.data.to(original_device)
            # New parameters or parameters already on target device are
            # untouched
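
# Illustrative usage sketch (added for clarity; not part of the original
# module, and the layer below is a stand-in): `device_loading_context`
# temporarily moves a module's CPU-resident parameters onto the target device
# and copies them back afterwards, e.g.:
#
#     layer = torch.nn.Linear(8, 8)                      # parameters on CPU
#     with device_loading_context(layer, torch.device("cuda:0")):
#         ...  # process weights while they live on the GPU
#     # on exit, the parameters are copied back into (optionally pinned)
#     # CPU buffers; parameters that were already on the GPU are untouched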


def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Get the quantization config."""
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config, load_config)
        capability = current_platform.get_device_capability()  # type: ignore
        if capability is not None:
            capability = capability[0] * 10 + capability[1]
            if capability < quant_config.get_min_capability():
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. "
                    f"Minimum capability: {quant_config.get_min_capability()}. "
                    f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        return quant_config
    return None


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module],
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
    """Get extra kwargs for model initialization."""
    extra_kwargs: Dict[str, Any] = {}

    if supports_lora(model_class):
        # lora_config=None is used to disable LoRA
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on GitHub.")

    if supports_multimodal(model_class):
        assert multimodal_config is not None
        extra_kwargs["multimodal_config"] = multimodal_config

    if has_inner_state(model_class) and scheduler_config:
        extra_kwargs["scheduler_config"] = scheduler_config

    return extra_kwargs


def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
                cache_config: Optional[CacheConfig],
                quant_config: Optional[QuantizationConfig], *,
                lora_config: Optional[LoRAConfig],
                multimodal_config: Optional[MultiModalConfig],
                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
    extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                    multimodal_config,
                                                    scheduler_config)

    return model_class(config=hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **extra_kwargs)


def _initialize_model(
        model_config: ModelConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        cache_config: CacheConfig,
        scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
    """Initialize a model with the given configurations."""
    model_class, _ = get_model_architecture(model_config)

    return build_model(
        model_class,
        model_config.hf_config,
        cache_config=cache_config,
        quant_config=_get_quantization_config(model_config, load_config),
        lora_config=lora_config,
        multimodal_config=model_config.multimodal_config,
        scheduler_config=scheduler_config,
    )


class BaseModelLoader(ABC):
    """Base class for model loaders."""

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model with the given configurations."""
        raise NotImplementedError


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
        True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if APHRODITE_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                    ignore_file_pattern=self.load_config.ignore_patterns,
                )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, index_file,
                    self.load_config.download_dir, revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, source: "Source"
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], int]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt)
        est_weight_bytes = sum(os.path.getsize(f) for f in hf_weights_files)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only supports *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path, self.load_config.download_dir, hf_folder,
                hf_weights_files)
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        # Apply the prefix
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator), est_weight_bytes

    def _get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], int]:
        """Get all weights including primary and secondary sources."""
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True))
        primary_iterator, primary_bytes = self._get_weights_iterator(
            primary_weights)
        total_bytes = primary_bytes

        # Collect all weight iterators and their sizes
        iterators = [primary_iterator]
        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
                                 getattr(model, "secondary_weights", ()))
        for source in secondary_weights:
            iterator, bytes_size = self._get_weights_iterator(source)
            iterators.append(iterator)
            total_bytes += bytes_size

        # Chain all iterators together
        from itertools import chain
        return chain.from_iterable(iterators), total_bytes

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config,
                                          scheduler_config)

            weights_iter, total_bytes = self._get_all_weights(
                model_config, model)
            model.load_weights(tensor_progress_bar(weights_iter, total_bytes,
                                                   "Loading model weights..."))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # Nothing to download

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config,
                                          scheduler_config)
            # NOTE: For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't Aphrodite-tensorized (see
        examples/tensorize_aphrodite_model.py). This should still be faster
        than default HuggingFace loading, but will be slower than loading an
        Aphrodite-tensorized model.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects an Aphrodite-tensorized model. See the
        examples/tensorize_aphrodite_model.py example script
        for serializing Aphrodite models."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, model_config.multimodal_config)
                extra_kwargs["quant_config"] = quant_config
                extra_kwargs["cache_config"] = cache_config

                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from aphrodite.distributed import get_tensor_model_parallel_rank

            self.tensorizer_config.tensorizer_uri = \
                self.tensorizer_config.tensorizer_uri \
                    % get_tensor_model_parallel_rank()

        if is_aphrodite_tensorized(self.tensorizer_config):
            return self._load_model_serialized(model_config, device_config,
                                               lora_config, cache_config)
        return self._load_model_serialized_cpu(model_config, device_config,
                                               lora_config, cache_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        serialize_aphrodite_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )


class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/save_sharded_state.py` for creating a sharded checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{load_config.model_loader_extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        same_storage_groups: Dict[Any, List[Tuple[
            str, torch.Tensor]]] = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        from safetensors.torch import safe_open

        from aphrodite.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning("loading tensor of shape "
                                           f"{tensor.shape} into parameter "
                                           f"'{key}' of shape {param_shape}")
                        param_data.copy_(tensor)
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from safetensors.torch import save_file

        from aphrodite.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
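
# Illustrative sketch (an assumption about a typical on-disk layout, not taken
# from this module): with DEFAULT_PATTERN, rank 0 of a tensor-parallel group
# saved via ShardedStateLoader.save_model(model, path, max_size=...) produces
# files such as
#
#     model-rank-0-part-0.safetensors
#     model-rank-0-part-1.safetensors
#
# and load_model() on rank 0 later globs for "model-rank-0-part-*.safetensors",
# so each worker reads only its own shard of the checkpoint.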


class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitsAndBytes quantization."""

    # TODO: these module names are for Llama only,
    # change so that it works with other models as well
    default_target_modules = [
        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
        "o_proj"
    ]

    possible_config_file_names = ["adapter_config.json"]

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # we don't need to quantize the whole model, only the target modules
        # that are specified in the adapter config file. If the adapter config
        # file is not provided, we will quantize the default modules.
        if (not load_config.model_loader_extra_config
                or "qlora_adapter_name_or_path"
                not in load_config.model_loader_extra_config):
            self.target_modules = self.default_target_modules
            return

        qlora_adapter = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"]
        config_file_path = self._get_config_file(qlora_adapter)

        with open(config_file_path, "r") as f:
            config = json.load(f)
            self.target_modules = config["target_modules"]

    def _get_config_file(self, qlora_adapter: str) -> str:
        is_local = os.path.isdir(qlora_adapter)
        config_file_path = None
        if is_local:
            for file in self.possible_config_file_names:
                config_file_path = os.path.join(qlora_adapter, file)
                if os.path.exists(config_file_path):
                    break
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
            for file in self.possible_config_file_names:
                if file in repo_files:
                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
                                                       filename=file)
                    break

        if not config_file_path:
            raise ValueError(
                f"Cannot find adapter config file in {qlora_adapter}")

        return config_file_path

    def _get_weight_files(
            self,
            model_name_or_path: str,
            allowed_patterns: List[str],
            revision: Optional[str] = None) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        Return the weight files and the file pattern."""
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model."""

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        else:
            return pt_weights_iterator(hf_weights_files)

    def _get_quantized_weights_iterator(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary."""

        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes

            if bitsandbytes.__version__ < "0.45.1":
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.45.1.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.45.1 via "
                              "`pip install bitsandbytes>=0.45.1` to use "
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict: Dict[str, Any] = {}

        if pre_quant:
            if load_8bit:
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict

    def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not weight_name.lower().endswith(".scb"):
                continue

            weight_key = weight_name.lower().replace(".scb", ".qweight")
            quant_state_dict[weight_key] = weight_tensor

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

            if not weight_name.endswith(".weight"):
                continue

            qweight_name = weight_name.replace(".weight", ".qweight")
            if qweight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
                yield qweight_name, weight_tensor
            else:
                yield weight_name, weight_tensor

    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        from bitsandbytes.functional import QuantState

        # First, iterate over all quant state weights
        weight_iterator = self._hf_weight_iter(hf_weights_files,
                                               use_safetensors)
        temp_state_dict = {}
        for weight_name, weight_tensor in weight_iterator:
            if weight_name.endswith(".weight"):
                continue
            # the bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* to be on CPU
            if "quant_state.bitsandbytes" in weight_name:
                temp_state_dict[weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str,
                               temp_state_dict: Dict) -> QuantState:
            quant_state = {}
            for k in temp_state_dict:
                if param_name + "." in k:
                    quant_state[k] = temp_state_dict[k]
            return QuantState.from_dict(quant_state, device="cuda")

        # Second, iterate over all prequant and normal weights;
        # pre-quantized weights would have a quant_state
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            # Filter out all weights whose suffix is not ".weight"
            if not weight_name.endswith(".weight"):
                continue
            if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
                in temp_state_dict) or \
               (f"{weight_name}.quant_state.bitsandbytes__fp4" \
                in temp_state_dict):
                quant_state = _parse_quant_state(weight_name, temp_state_dict)
                weight_name = weight_name.replace(".weight", ".qweight")
                quant_state_dict[weight_name] = quant_state
                yield weight_name.replace(".weight", ".qweight"), weight_tensor
            else:
                yield weight_name, weight_tensor

    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        from bitsandbytes.functional import quantize_4bit
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if any(target_module in weight_name
                   for target_module in self.target_modules):
                weight_name = weight_name.replace(".weight", ".qweight")

                # weight partitions of different modules occur at
                # different dimensions
                # TODO: these module names are for Llama only,
                # change so that it works with other models as well
                if 'down_proj' in weight_name or 'o_proj' in weight_name:
                    total_size = weight_tensor.size(-1)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[...,
                                                      start_index:end_index]
                else:
                    total_size = weight_tensor.size(0)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[start_index:end_index,
                                                      ...]

                # bitsandbytes requires data on the GPU
                if weight_sub_tensor.is_cuda:
                    loaded_weight = weight_sub_tensor
                else:
                    loaded_weight = weight_sub_tensor.cuda()

                # remove the following after the issue is fixed:
                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
                if loaded_weight.is_contiguous() is False:
                    loaded_weight = loaded_weight.contiguous()

                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
                        compress_statistics=True,
                        quant_type="nf4")

                quant_state_dict[weight_name] = quant_state
            else:
                processed_weight = weight_tensor

            yield weight_name, processed_weight

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        if not hasattr(model, 'load_weights'):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet.")

        logger.info("Loading weights with BitsAndBytes quantization. "
                    "This may take a while ...")

        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)

        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get('quant_method')
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with TP is not supported. "
                "Please try with PP.")

        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get('load_in_8bit', False)

        qweight_iterator, quant_state_dict = \
            self._get_quantized_weights_iterator(
                model_config.model, model_config.revision, pre_quant,
                load_8bit)

        model.load_weights(qweight_iterator)

        torch.cuda.empty_cache()

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for quant_param_name in quant_state_dict:
            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                    weight_name, index
            ) in model.bitsandbytes_stacked_params_mapping.items():
                if shard_name in quant_param_name:
                    shard_index = index
                    quant_param_name = quant_param_name.replace(
                        shard_name, weight_name)
                    break

            if quant_param_name not in param_dict:
                raise ValueError(
                    f"Parameter {quant_param_name} not found in the model.")

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = (
                quant_state_dict[non_stacked_param_name])

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
                for seq, quant_state in quant_states.items():
                    num_elements[seq] = math.prod(
                        quant_state.shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

                if load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)})

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)

                self._load_weights(model_config, model)

        return model.eval()


class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        if os.path.isfile(model_name_or_path):
            return model_name_or_path
        else:
            raise ValueError(f"{model_name_or_path} is not a file.")

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        GGUF uses the following naming convention for tensors that correspond
        to an HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
        """
        config = model_config.hf_config
        model_type = config.model_type
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        arch = None
        for key, value in gguf.MODEL_ARCH_NAMES.items():
            if value == model_type:
                arch = key
                break
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        num_layers = config.num_hidden_layers
        name_map = gguf.get_tensor_name_map(arch, num_layers)
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(config)
        state_dict = dummy_model.state_dict()

        gguf_to_hf_name_map = {}
        for hf_name in state_dict:
            name, suffix = hf_name.rsplit(".", 1)
            gguf_name = name_map.get_name(name)
            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        local_model_path = self._prepare_weights(model_config.model)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # we can only know whether word embeddings are tied after mapping
        # the weights
        if "lm_head.weight" in get_gguf_extra_tensor_names(
                local_model_path, gguf_weights_map):
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)
            model.load_weights(
                self._get_weights_iterator(local_model_path, gguf_weights_map))
        return model


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    if load_config.load_format == LoadFormat.BITSANDBYTES:
        return BitsAndBytesModelLoader(load_config)

    if load_config.load_format == LoadFormat.GGUF:
        return GGUFModelLoader(load_config)

    return DefaultModelLoader(load_config)
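
# Illustrative call-site sketch (assumed engine-side usage, not part of this
# module): the engine builds a LoadConfig, resolves a loader from the load
# format, and asks it to download and materialize the model:
#
#     loader = get_model_loader(load_config)
#     loader.download_model(model_config)
#     model = loader.load_model(model_config=model_config,
#                               device_config=device_config,
#                               lora_config=None,
#                               parallel_config=parallel_config,
#                               scheduler_config=scheduler_config,
#                               cache_config=cache_config)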