# ruff: noqa: SIM117
import collections
import copy
import glob
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Type

import huggingface_hub
import torch
from loguru import logger
from torch import nn

from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                     DeviceConfig, LoadConfig, LoadFormat,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
    tensorizer_weights_iterator)
from aphrodite.modeling.model_loader.utils import (get_model_architecture,
                                                   set_default_torch_dtype)
from aphrodite.modeling.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
from aphrodite.modeling.models.vlm_base import VisionLanguageModelBase
from aphrodite.quantization.base_config import QuantizationConfig


def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Get the quantization config."""
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config, load_config)
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        return quant_config
    return None
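
# Example (illustrative): an RTX 3090 reports
# torch.cuda.get_device_capability() == (8, 6), which packs to
# 8 * 10 + 6 == 86; a quantization method whose get_min_capability() is 70
# passes the check above, while one requiring 89 or higher would raise.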


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]:
    """Get extra kwargs for model initialization."""
    extra_kwargs = {}
    if hasattr(model_class, "supported_lora_modules"):
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on GitHub.")
    elif issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("Provide `image_input_type` and other vision "
                             "related configurations through LLM entrypoint "
                             "or engine arguments.")
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs
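
# Example (illustrative): for a model class that declares
# `supported_lora_modules`, the result is {"lora_config": lora_config}; for a
# VisionLanguageModelBase subclass it is
# {"vision_language_config": vision_language_config}; otherwise it is empty.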


def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
                      lora_config: Optional[LoRAConfig],
                      vision_language_config: Optional[VisionLanguageConfig],
                      cache_config: CacheConfig) -> nn.Module:
    """Initialize a model with the given configurations."""
    model_class = get_model_architecture(model_config)[0]
    quant_config = _get_quantization_config(model_config, load_config)
    return model_class(config=model_config.hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **_get_model_initialization_kwargs(
                           model_class, lora_config, vision_language_config))


class BaseModelLoader(ABC):
    """Base class for model loaders."""

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model with the given configurations."""
        ...


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
        True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if APHRODITE_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(model_name_or_path,
                                                 self.load_config.download_dir,
                                                 allow_patterns, revision)
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir,
                    revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors
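
    # Example (illustrative): with LoadFormat.AUTO and a repo that ships both
    # sharded "model-00001-of-00002.safetensors" files and a consolidated
    # "consolidated.safetensors", the index filter above keeps only the files
    # listed in "model.safetensors.index.json", so this returns
    # (hf_folder, [sharded file paths], True).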

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str],
        fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only supports *.bin checkpoints
            assert use_safetensors is False
            return np_cache_weights_iterator(model_name_or_path,
                                             self.load_config.download_dir,
                                             hf_folder, hf_weights_files)
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        return pt_weights_iterator(hf_weights_files)
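
    # Whatever the backing format, the returned generator yields
    # (parameter_name, tensor) pairs, e.g.
    # ("model.layers.0.self_attn.q_proj.weight", <torch.Tensor>), which is
    # the contract load_weights() consumes below.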

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
            model.load_weights(
                self._get_weights_iterator(model_config.model,
                                           model_config.revision,
                                           fall_back_to_pt=getattr(
                                               model,
                                               "fall_back_to_pt_during_load",
                                               True)), )

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()
        return model.eval()
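
# Note on the flow above: the model skeleton is materialized directly on the
# target device under the configured default dtype, weights are streamed in
# through the iterator, and each module then gets a chance to repack its
# weights (e.g. quantized layers) before the model is returned in eval mode.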


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
            # NOTE: For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)
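
    # Example (illustrative; valid field names are whatever TensorizerConfig
    # accepts): the extra config is typically a dict such as
    # {"tensorizer_uri": "s3://bucket/model.tensors"}, unpacked into a
    # TensorizerConfig here.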

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't Aphrodite-tensorized (see
        examples/tensorize_aphrodite_model.py). This should still be faster
        than default HuggingFace loading, but will be slower than loading an
        Aphrodite-tensorized model.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects an Aphrodite-tensorized model. See the
        examples/tensorize_aphrodite_model.py example script
        for serializing Aphrodite models."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, vision_language_config)
                extra_kwargs["quant_config"] = quant_config
                extra_kwargs["cache_config"] = cache_config

                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
        return model.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        self._verify_config(model_config, parallel_config)

        if is_aphrodite_tensorized(self.tensorizer_config):
            return self._load_model_serialized(model_config, device_config,
                                               lora_config,
                                               vision_language_config,
                                               cache_config)
        return self._load_model_serialized_cpu(model_config, device_config,
                                               lora_config,
                                               vision_language_config,
                                               cache_config)


class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each
    worker only needs to read its own shard rather than the entire
    checkpoint. See `examples/save_sharded_states.py` for creating a sharded
    checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
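
    # With the default pattern, tensor-parallel rank 0 writes
    # "model-rank-0-part-0.safetensors", "model-rank-0-part-1.safetensors",
    # and so on; load_model() below globs with part="*" to collect every
    # part belonging to the current rank.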

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{load_config.model_loader_extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        same_storage_groups = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result
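
    # Worked example (illustrative): in a model that ties its output
    # projection to its input embedding, "lm_head.weight" and
    # "model.embed_tokens.weight" share one storage and span the same byte
    # range, so the for-else above keeps only the lexicographically smaller
    # key and each byte of storage is serialized exactly once.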

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(model_name_or_path,
                                            self.load_config.download_dir,
                                            allow_patterns, revision)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        from safetensors.torch import safe_open

        from aphrodite.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning("loading tensor of shape "
                                           f"{tensor.shape} into parameter "
                                           f"'{key}' of shape {param_shape}")
                        param_data.copy_(tensor)
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from safetensors.torch import save_file

        from aphrodite.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
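
# Minimal save/load roundtrip sketch (hypothetical path; each tensor-parallel
# worker calls save_model on its own shard of the model):
#
#   ShardedStateLoader.save_model(model, "/ckpt", max_size=4 << 30)
#
# Reloading with load_format=LoadFormat.SHARDED_STATE and model="/ckpt" then
# lets each rank read back only its own model-rank-{rank}-part-*.safetensors.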


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    return DefaultModelLoader(load_config)
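
# Usage sketch (illustrative; the engine normally calls this internally, and
# the config objects here are assumed to be built elsewhere):
#
#   load_config = LoadConfig(load_format=LoadFormat.AUTO)
#   loader = get_model_loader(load_config)  # -> DefaultModelLoader
#   model = loader.load_model(model_config=model_config,
#                             device_config=device_config,
#                             lora_config=None,
#                             vision_language_config=None,
#                             parallel_config=parallel_config,
#                             scheduler_config=scheduler_config,
#                             cache_config=cache_config)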