1
0

loader.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. # ruff: noqa: SIM117
  2. import copy
  3. import gc
  4. import glob
  5. import os
  6. from abc import ABC, abstractmethod
  7. from contextlib import nullcontext
  8. from typing import (Any, Dict, Generator, List, Optional, Tuple, Type)
  9. import torch
  10. from torch import nn
  11. from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, DeviceConfig,
  12. LoadConfig, LoadFormat, LoRAConfig,
  13. ModelConfig, ParallelConfig,
  14. SchedulerConfig, VisionLanguageConfig)
  15. from aphrodite.modeling.model_loader.tensorizer import (
  16. TensorizerConfig, is_aphrodite_serialized_tensorizer, load_with_tensorizer,
  17. tensorizer_weights_iterator)
  18. from aphrodite.modeling.model_loader.utils import (get_model_architecture,
  19. set_default_torch_dtype)
  20. from aphrodite.modeling.model_loader.weight_utils import (
  21. download_weights_from_hf, filter_files_not_needed_for_inference,
  22. get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
  23. pt_weights_iterator, safetensors_weights_iterator)
  24. from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
  25. from aphrodite.quantization.base_config import QuantizationConfig
  26. from aphrodite.quantization.bitsandbytes import (BNBLinearMethod,
  27. replace_quant_params)
  28. _VISION_MODEL_CLASSES = [
  29. LlavaForConditionalGeneration,
  30. ]
  31. def _get_quantization_config(
  32. model_config: ModelConfig,
  33. load_config: LoadConfig) -> Optional[QuantizationConfig]:
  34. """Get the quantization config."""
  35. if model_config.quantization is not None:
  36. quant_config = get_quant_config(model_config, load_config)
  37. capability = torch.cuda.get_device_capability()
  38. capability = capability[0] * 10 + capability[1]
  39. if capability < quant_config.get_min_capability():
  40. raise ValueError(
  41. f"The quantization method {model_config.quantization} is not "
  42. "supported for the current GPU. "
  43. f"Minimum capability: {quant_config.get_min_capability()}. "
  44. f"Current capability: {capability}.")
  45. supported_dtypes = quant_config.get_supported_act_dtypes()
  46. if model_config.dtype not in supported_dtypes:
  47. raise ValueError(
  48. f"{model_config.dtype} is not supported for quantization "
  49. f"method {model_config.quantization}. Supported dtypes: "
  50. f"{supported_dtypes}")
  51. return quant_config
  52. return None
  53. def _get_model_initialization_kwargs(
  54. model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
  55. vision_language_config: Optional[VisionLanguageConfig]
  56. ) -> Dict[str, Any]:
  57. """Get extra kwargs for model initialization."""
  58. extra_kwargs = {}
  59. if hasattr(model_class, "supported_lora_modules"):
  60. extra_kwargs["lora_config"] = lora_config
  61. elif lora_config:
  62. raise ValueError(
  63. f"Model {model_class.__name__} does not support LoRA, "
  64. "but LoRA is enabled. Support for this model may "
  65. "be added in the future. If this is important to you, "
  66. "please open an issue on github.")
  67. elif model_class in _VISION_MODEL_CLASSES:
  68. extra_kwargs["vision_language_config"] = vision_language_config
  69. return extra_kwargs
  70. def _initialize_model(
  71. model_config: ModelConfig, load_config: LoadConfig,
  72. lora_config: Optional[LoRAConfig],
  73. vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
  74. """Initialize a model with the given configurations."""
  75. model_class = get_model_architecture(model_config)[0]
  76. quant_config = _get_quantization_config(model_config, load_config)
  77. return model_class(config=model_config.hf_config,
  78. quant_config=quant_config,
  79. **_get_model_initialization_kwargs(
  80. model_class, lora_config, vision_language_config))
  81. class BaseModelLoader(ABC):
  82. """Base class for model loaders."""
  83. def __init__(self, load_config: LoadConfig):
  84. self.load_config = load_config
  85. @abstractmethod
  86. def load_model(self, *, model_config: ModelConfig,
  87. device_config: DeviceConfig,
  88. lora_config: Optional[LoRAConfig],
  89. vision_language_config: Optional[VisionLanguageConfig],
  90. parallel_config: ParallelConfig,
  91. scheduler_config: SchedulerConfig) -> nn.Module:
  92. """Load a model with the given configurations."""
  93. ...
  94. class DefaultModelLoader(BaseModelLoader):
  95. """Model loader that can load different file types from disk."""
  96. def __init__(self, load_config: LoadConfig):
  97. super().__init__(load_config)
  98. if load_config.model_loader_extra_config:
  99. raise ValueError(f"Model loader extra config is not supported for "
  100. f"load format {load_config.load_format}")
  101. def _maybe_download_from_modelscope(
  102. self, model: str, revision: Optional[str]) -> Optional[str]:
  103. """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
  104. True.
  105. Returns the path to the downloaded model, or None if the model is not
  106. downloaded from ModelScope."""
  107. if APHRODITE_USE_MODELSCOPE:
  108. # download model from ModelScope hub,
  109. # lazy import so that modelscope is not required for normal use.
  110. # pylint: disable=C.
  111. from modelscope.hub.snapshot_download import snapshot_download
  112. if not os.path.exists(model):
  113. model_path = snapshot_download(
  114. model_id=model,
  115. cache_dir=self.load_config.download_dir,
  116. revision=revision)
  117. else:
  118. model_path = model
  119. return model_path
  120. return None
  121. def _prepare_weights(self, model_name_or_path: str,
  122. revision: Optional[str],
  123. fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
  124. """Prepare weights for the model.
  125. If the model is not local, it will be downloaded."""
  126. model_name_or_path = self._maybe_download_from_modelscope(
  127. model_name_or_path, revision) or model_name_or_path
  128. is_local = os.path.isdir(model_name_or_path)
  129. load_format = self.load_config.load_format
  130. use_safetensors = False
  131. # Some quantized models use .pt files for storing the weights.
  132. if load_format == LoadFormat.AUTO:
  133. allow_patterns = ["*.safetensors", "*.bin"]
  134. elif load_format == LoadFormat.SAFETENSORS:
  135. use_safetensors = True
  136. allow_patterns = ["*.safetensors"]
  137. elif load_format == LoadFormat.PT:
  138. allow_patterns = ["*.pt"]
  139. elif load_format == LoadFormat.NPCACHE:
  140. allow_patterns = ["*.bin"]
  141. else:
  142. raise ValueError(f"Unknown load_format: {load_format}")
  143. if fall_back_to_pt:
  144. allow_patterns += ["*.pt"]
  145. if not is_local:
  146. hf_folder = download_weights_from_hf(model_name_or_path,
  147. self.load_config.download_dir,
  148. allow_patterns, revision)
  149. else:
  150. hf_folder = model_name_or_path
  151. hf_weights_files: List[str] = []
  152. for pattern in allow_patterns:
  153. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  154. if len(hf_weights_files) > 0:
  155. if pattern == "*.safetensors":
  156. use_safetensors = True
  157. break
  158. if not use_safetensors:
  159. hf_weights_files = filter_files_not_needed_for_inference(
  160. hf_weights_files)
  161. if len(hf_weights_files) == 0:
  162. raise RuntimeError(
  163. f"Cannot find any model weights with `{model_name_or_path}`")
  164. return hf_folder, hf_weights_files, use_safetensors
  165. def _get_weights_iterator(
  166. self, model_name_or_path: str, revision: Optional[str],
  167. fall_back_to_pt: bool
  168. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  169. """Get an iterator for the model weights based on the load format."""
  170. hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
  171. model_name_or_path, revision, fall_back_to_pt)
  172. if self.load_config.load_format == LoadFormat.NPCACHE:
  173. # Currently np_cache only support *.bin checkpoints
  174. assert use_safetensors is False
  175. return np_cache_weights_iterator(model_name_or_path,
  176. self.load_config.download_dir,
  177. hf_folder, hf_weights_files)
  178. if use_safetensors:
  179. return safetensors_weights_iterator(hf_weights_files)
  180. return pt_weights_iterator(hf_weights_files)
  181. def load_model(self, *, model_config: ModelConfig,
  182. device_config: DeviceConfig,
  183. lora_config: Optional[LoRAConfig],
  184. vision_language_config: Optional[VisionLanguageConfig],
  185. parallel_config: ParallelConfig,
  186. scheduler_config: SchedulerConfig) -> nn.Module:
  187. with set_default_torch_dtype(model_config.dtype):
  188. linear_method = _get_quantization_config(model_config,
  189. self.load_config)
  190. context = torch.device(device_config.device) if not (
  191. isinstance(linear_method, BNBLinearMethod)
  192. and linear_method.quant_config.from_float) else nullcontext()
  193. with context:
  194. model = _initialize_model(model_config, self.load_config,
  195. lora_config, vision_language_config)
  196. model.load_weights(
  197. self._get_weights_iterator(model_config.model,
  198. model_config.revision,
  199. fall_back_to_pt=getattr(
  200. model,
  201. "fall_back_to_pt_during_load",
  202. True)), )
  203. for _, module in model.named_modules():
  204. quant_method = getattr(module, "quant_method", None)
  205. if quant_method is not None:
  206. quant_method.process_weights_after_loading(module)
  207. # FIXME: Remove this after Mixtral is updated
  208. # to use quant_method.
  209. if hasattr(module, "process_weights_after_loading"):
  210. module.process_weights_after_loading()
  211. if isinstance(linear_method, BNBLinearMethod):
  212. replace_quant_params(
  213. model,
  214. quant_config=linear_method.quant_config,
  215. modules_to_not_convert="lm_head",
  216. )
  217. torch.cuda.synchronize()
  218. if linear_method.quant_config.from_float:
  219. model = model.cuda()
  220. gc.collect()
  221. torch.cuda.empty_cache()
  222. return model.eval()
  223. class DummyModelLoader(BaseModelLoader):
  224. """Model loader that will set model weights to random values."""
  225. def __init__(self, load_config: LoadConfig):
  226. super().__init__(load_config)
  227. if load_config.model_loader_extra_config:
  228. raise ValueError(f"Model loader extra config is not supported for "
  229. f"load format {load_config.load_format}")
  230. def load_model(self, *, model_config: ModelConfig,
  231. device_config: DeviceConfig,
  232. lora_config: Optional[LoRAConfig],
  233. vision_language_config: Optional[VisionLanguageConfig],
  234. parallel_config: ParallelConfig,
  235. scheduler_config: SchedulerConfig) -> nn.Module:
  236. with set_default_torch_dtype(model_config.dtype):
  237. with torch.device(device_config.device):
  238. model = _initialize_model(model_config, self.load_config,
  239. lora_config, vision_language_config)
  240. # NOTE(woosuk): For accurate performance evaluation, we assign
  241. # random values to the weights.
  242. initialize_dummy_weights(model)
  243. return model.eval()
  244. class TensorizerLoader(BaseModelLoader):
  245. """Model loader using CoreWeave's tensorizer library."""
  246. def __init__(self, load_config: LoadConfig):
  247. super().__init__(load_config)
  248. if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
  249. self.tensorizer_config = load_config.model_loader_extra_config
  250. else:
  251. self.tensorizer_config = TensorizerConfig(
  252. **load_config.model_loader_extra_config)
  253. def _verify_config(self, model_config: ModelConfig,
  254. parallel_config: ParallelConfig):
  255. self.tensorizer_config.verify_with_model_config(model_config)
  256. self.tensorizer_config.verify_with_parallel_config(parallel_config)
  257. def _get_weights_iterator(
  258. self) -> Generator[Tuple[str, torch.Tensor], None, None]:
  259. tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
  260. return tensorizer_weights_iterator(tensorizer_args)
  261. def _load_model_unserialized(
  262. self, model_config: ModelConfig, device_config: DeviceConfig,
  263. lora_config: Optional[LoRAConfig],
  264. vision_language_config: Optional[VisionLanguageConfig]
  265. ) -> nn.Module:
  266. """Load an unserialized model with tensorizer.
  267. Unserialized here means "not serialized with tensorizer". This
  268. should still be faster than default HuggingFace loading, but will
  269. be slower than loading a tensorizer-serialized model.
  270. """
  271. with set_default_torch_dtype(model_config.dtype):
  272. with torch.device(device_config.device):
  273. model = _initialize_model(model_config, self.load_config,
  274. lora_config, vision_language_config)
  275. model.load_weights(self._get_weights_iterator())
  276. return model.eval()
  277. def _load_model_serialized(
  278. self, model_config: ModelConfig, device_config: DeviceConfig,
  279. lora_config: Optional[LoRAConfig],
  280. vision_language_config: Optional[VisionLanguageConfig]
  281. ) -> nn.Module:
  282. """Load a serialized model with tensorizer.
  283. See the examples/tensorize_aphrodite_model.py example "
  284. script for serializing Aphrodite models."""
  285. with set_default_torch_dtype(model_config.dtype):
  286. with torch.device(device_config.device):
  287. model_class = get_model_architecture(model_config)[0]
  288. quant_config = _get_quantization_config(
  289. model_config, self.load_config)
  290. extra_kwargs = _get_model_initialization_kwargs(
  291. model_class, lora_config, vision_language_config)
  292. extra_kwargs["quant_config"] = quant_config
  293. tensorizer_config = copy.copy(self.tensorizer_config)
  294. tensorizer_config.model_class = model_class
  295. tensorizer_config.hf_config = model_config.hf_config
  296. tensorizer_config.dtype = model_config.dtype
  297. model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
  298. return model.eval()
  299. def load_model(self, *, model_config: ModelConfig,
  300. device_config: DeviceConfig,
  301. lora_config: Optional[LoRAConfig],
  302. vision_language_config: Optional[VisionLanguageConfig],
  303. parallel_config: ParallelConfig,
  304. scheduler_config: SchedulerConfig) -> nn.Module:
  305. self._verify_config(model_config, parallel_config)
  306. if is_aphrodite_serialized_tensorizer(self.tensorizer_config):
  307. return self._load_model_serialized(model_config, device_config,
  308. lora_config,
  309. vision_language_config)
  310. return self._load_model_unserialized(model_config, device_config,
  311. lora_config,
  312. vision_language_config)
  313. def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  314. """Get a model loader based on the load format."""
  315. if isinstance(load_config.load_format, type):
  316. return load_config.load_format(load_config)
  317. if load_config.load_format == LoadFormat.DUMMY:
  318. return DummyModelLoader(load_config)
  319. if load_config.load_format == LoadFormat.TENSORIZER:
  320. return TensorizerLoader(load_config)
  321. return DefaultModelLoader(load_config)