1
0

loader.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # ruff: noqa: SIM117
  2. import copy
  3. import glob
  4. import os
  5. from abc import ABC, abstractmethod
  6. import gc
  7. from contextlib import nullcontext
  8. from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
  9. Type)
  10. import torch
  11. from torch import nn
  12. from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, DeviceConfig,
  13. LoadConfig, LoadFormat, LoRAConfig,
  14. ModelConfig, ParallelConfig,
  15. SchedulerConfig, VisionLanguageConfig)
  16. from aphrodite.modeling.model_loader.tensorizer import (
  17. TensorizerConfig, is_aphrodite_serialized_tensorizer, load_with_tensorizer,
  18. tensorizer_weights_iterator)
  19. from aphrodite.modeling.model_loader.utils import (get_model_architecture,
  20. set_default_torch_dtype)
  21. from aphrodite.modeling.model_loader.weight_utils import (
  22. download_weights_from_hf, filter_files_not_needed_for_inference,
  23. get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
  24. pt_weights_iterator, safetensors_weights_iterator)
  25. from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
  26. from aphrodite.quantization.bitsandbytes import (BNBLinearMethod,
  27. replace_quant_params)
  28. if TYPE_CHECKING:
  29. from aphrodite.modeling.layers.linear import LinearMethodBase
  30. _VISION_MODEL_CLASSES = [
  31. LlavaForConditionalGeneration,
  32. ]
  33. def _get_linear_method(
  34. model_config: ModelConfig,
  35. load_config: LoadConfig) -> Optional["LinearMethodBase"]:
  36. """Get the (maybe quantized) linear method."""
  37. linear_method = None
  38. if model_config.quantization is not None:
  39. quant_config = get_quant_config(model_config, load_config)
  40. capability = torch.cuda.get_device_capability()
  41. capability = capability[0] * 10 + capability[1]
  42. if capability < quant_config.get_min_capability():
  43. raise ValueError(
  44. f"The quantization method {model_config.quantization} is not "
  45. "supported for the current GPU. "
  46. f"Minimum capability: {quant_config.get_min_capability()}. "
  47. f"Current capability: {capability}.")
  48. supported_dtypes = quant_config.get_supported_act_dtypes()
  49. if model_config.dtype not in supported_dtypes:
  50. raise ValueError(
  51. f"{model_config.dtype} is not supported for quantization "
  52. f"method {model_config.quantization}. Supported dtypes: "
  53. f"{supported_dtypes}")
  54. linear_method = quant_config.get_linear_method()
  55. return linear_method
  56. def _get_model_initialization_kwargs(
  57. model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
  58. vision_language_config: Optional[VisionLanguageConfig]
  59. ) -> Dict[str, Any]:
  60. """Get extra kwargs for model initialization."""
  61. extra_kwargs = {}
  62. if hasattr(model_class, "supported_lora_modules"):
  63. extra_kwargs["lora_config"] = lora_config
  64. elif lora_config:
  65. raise ValueError(
  66. f"Model {model_class.__name__} does not support LoRA, "
  67. "but LoRA is enabled. Support for this model may "
  68. "be added in the future. If this is important to you, "
  69. "please open an issue on github.")
  70. elif model_class in _VISION_MODEL_CLASSES:
  71. extra_kwargs["vision_language_config"] = vision_language_config
  72. return extra_kwargs
  73. def _initialize_model(
  74. model_config: ModelConfig, load_config: LoadConfig,
  75. lora_config: Optional[LoRAConfig],
  76. vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
  77. """Initialize a model with the given configurations."""
  78. model_class = get_model_architecture(model_config)[0]
  79. linear_method = _get_linear_method(model_config, load_config)
  80. return model_class(config=model_config.hf_config,
  81. linear_method=linear_method,
  82. **_get_model_initialization_kwargs(
  83. model_class, lora_config, vision_language_config))
  84. class BaseModelLoader(ABC):
  85. """Base class for model loaders."""
  86. def __init__(self, load_config: LoadConfig):
  87. self.load_config = load_config
  88. @abstractmethod
  89. def load_model(self, *, model_config: ModelConfig,
  90. device_config: DeviceConfig,
  91. lora_config: Optional[LoRAConfig],
  92. vision_language_config: Optional[VisionLanguageConfig],
  93. parallel_config: ParallelConfig,
  94. scheduler_config: SchedulerConfig) -> nn.Module:
  95. """Load a model with the given configurations."""
  96. ...
  97. class DefaultModelLoader(BaseModelLoader):
  98. """Model loader that can load different file types from disk."""
  99. def __init__(self, load_config: LoadConfig):
  100. super().__init__(load_config)
  101. if load_config.model_loader_extra_config:
  102. raise ValueError(f"Model loader extra config is not supported for "
  103. f"load format {load_config.load_format}")
  104. def _maybe_download_from_modelscope(
  105. self, model: str, revision: Optional[str]) -> Optional[str]:
  106. """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
  107. True.
  108. Returns the path to the downloaded model, or None if the model is not
  109. downloaded from ModelScope."""
  110. if APHRODITE_USE_MODELSCOPE:
  111. # download model from ModelScope hub,
  112. # lazy import so that modelscope is not required for normal use.
  113. # pylint: disable=C.
  114. from modelscope.hub.snapshot_download import snapshot_download
  115. if not os.path.exists(model):
  116. model_path = snapshot_download(
  117. model_id=model,
  118. cache_dir=self.load_config.download_dir,
  119. revision=revision)
  120. else:
  121. model_path = model
  122. return model_path
  123. return None
  124. def _prepare_weights(self, model_name_or_path: str,
  125. revision: Optional[str],
  126. fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
  127. """Prepare weights for the model.
  128. If the model is not local, it will be downloaded."""
  129. model_name_or_path = self._maybe_download_from_modelscope(
  130. model_name_or_path, revision) or model_name_or_path
  131. is_local = os.path.isdir(model_name_or_path)
  132. load_format = self.load_config.load_format
  133. use_safetensors = False
  134. # Some quantized models use .pt files for storing the weights.
  135. if load_format == LoadFormat.AUTO:
  136. allow_patterns = ["*.safetensors", "*.bin"]
  137. elif load_format == LoadFormat.SAFETENSORS:
  138. use_safetensors = True
  139. allow_patterns = ["*.safetensors"]
  140. elif load_format == LoadFormat.PT:
  141. allow_patterns = ["*.pt"]
  142. elif load_format == LoadFormat.NPCACHE:
  143. allow_patterns = ["*.bin"]
  144. else:
  145. raise ValueError(f"Unknown load_format: {load_format}")
  146. if fall_back_to_pt:
  147. allow_patterns += ["*.pt"]
  148. if not is_local:
  149. hf_folder = download_weights_from_hf(model_name_or_path,
  150. self.load_config.download_dir,
  151. allow_patterns, revision)
  152. else:
  153. hf_folder = model_name_or_path
  154. hf_weights_files: List[str] = []
  155. for pattern in allow_patterns:
  156. hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
  157. if len(hf_weights_files) > 0:
  158. if pattern == "*.safetensors":
  159. use_safetensors = True
  160. break
  161. if not use_safetensors:
  162. hf_weights_files = filter_files_not_needed_for_inference(
  163. hf_weights_files)
  164. if len(hf_weights_files) == 0:
  165. raise RuntimeError(
  166. f"Cannot find any model weights with `{model_name_or_path}`")
  167. return hf_folder, hf_weights_files, use_safetensors
  168. def _get_weights_iterator(
  169. self, model_name_or_path: str, revision: Optional[str],
  170. fall_back_to_pt: bool
  171. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  172. """Get an iterator for the model weights based on the load format."""
  173. hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
  174. model_name_or_path, revision, fall_back_to_pt)
  175. if self.load_config.load_format == LoadFormat.NPCACHE:
  176. # Currently np_cache only support *.bin checkpoints
  177. assert use_safetensors is False
  178. return np_cache_weights_iterator(model_name_or_path,
  179. self.load_config.download_dir,
  180. hf_folder, hf_weights_files)
  181. if use_safetensors:
  182. return safetensors_weights_iterator(hf_weights_files)
  183. return pt_weights_iterator(hf_weights_files)
  184. def load_model(self, *, model_config: ModelConfig,
  185. device_config: DeviceConfig,
  186. lora_config: Optional[LoRAConfig],
  187. vision_language_config: Optional[VisionLanguageConfig],
  188. parallel_config: ParallelConfig,
  189. scheduler_config: SchedulerConfig) -> nn.Module:
  190. with set_default_torch_dtype(model_config.dtype):
  191. linear_method = _get_linear_method(model_config, self.load_config)
  192. context = torch.device(device_config.device) if not (
  193. isinstance(linear_method, BNBLinearMethod)
  194. and linear_method.quant_config.from_float) else nullcontext()
  195. with context:
  196. model = _initialize_model(model_config, self.load_config,
  197. lora_config, vision_language_config)
  198. model.load_weights(
  199. self._get_weights_iterator(model_config.model,
  200. model_config.revision,
  201. fall_back_to_pt=getattr(
  202. model,
  203. "fall_back_to_pt_during_load",
  204. True)), )
  205. for _, module in model.named_modules():
  206. linear_method = getattr(module, "linear_method", None)
  207. if linear_method is not None:
  208. linear_method.process_weights_after_loading(module)
  209. if hasattr(module, "process_weights_after_loading"):
  210. module.process_weights_after_loading()
  211. if isinstance(linear_method, BNBLinearMethod):
  212. replace_quant_params(
  213. model,
  214. quant_config=linear_method.quant_config,
  215. modules_to_not_convert="lm_head",
  216. )
  217. torch.cuda.synchronize()
  218. if linear_method.quant_config.from_float:
  219. model = model.cuda()
  220. gc.collect()
  221. torch.cuda.empty_cache()
  222. return model.eval()
  223. class DummyModelLoader(BaseModelLoader):
  224. """Model loader that will set model weights to random values."""
  225. def __init__(self, load_config: LoadConfig):
  226. super().__init__(load_config)
  227. if load_config.model_loader_extra_config:
  228. raise ValueError(f"Model loader extra config is not supported for "
  229. f"load format {load_config.load_format}")
  230. def load_model(self, *, model_config: ModelConfig,
  231. device_config: DeviceConfig,
  232. lora_config: Optional[LoRAConfig],
  233. vision_language_config: Optional[VisionLanguageConfig],
  234. parallel_config: ParallelConfig,
  235. scheduler_config: SchedulerConfig) -> nn.Module:
  236. with set_default_torch_dtype(model_config.dtype):
  237. with torch.device(device_config.device):
  238. model = _initialize_model(model_config, self.load_config,
  239. lora_config, vision_language_config)
  240. # NOTE(woosuk): For accurate performance evaluation, we assign
  241. # random values to the weights.
  242. initialize_dummy_weights(model)
  243. return model.eval()
  244. class TensorizerLoader(BaseModelLoader):
  245. """Model loader using CoreWeave's tensorizer library."""
  246. def __init__(self, load_config: LoadConfig):
  247. super().__init__(load_config)
  248. if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
  249. self.tensorizer_config = load_config.model_loader_extra_config
  250. else:
  251. self.tensorizer_config = TensorizerConfig(
  252. **load_config.model_loader_extra_config)
  253. def _verify_config(self, model_config: ModelConfig,
  254. parallel_config: ParallelConfig):
  255. self.tensorizer_config.verify_with_model_config(model_config)
  256. self.tensorizer_config.verify_with_parallel_config(parallel_config)
  257. def _get_weights_iterator(
  258. self) -> Generator[Tuple[str, torch.Tensor], None, None]:
  259. tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
  260. return tensorizer_weights_iterator(tensorizer_args)
  261. def _load_model_unserialized(
  262. self, model_config: ModelConfig, device_config: DeviceConfig,
  263. lora_config: Optional[LoRAConfig],
  264. vision_language_config: Optional[VisionLanguageConfig]
  265. ) -> nn.Module:
  266. """Load an unserialized model with tensorizer.
  267. Unserialized here means "not serialized with tensorizer". This
  268. should still be faster than default HuggingFace loading, but will
  269. be slower than loading a tensorizer-serialized model.
  270. """
  271. with set_default_torch_dtype(model_config.dtype):
  272. with torch.device(device_config.device):
  273. model = _initialize_model(model_config, self.load_config,
  274. lora_config, vision_language_config)
  275. model.load_weights(self._get_weights_iterator())
  276. return model.eval()
  277. def _load_model_serialized(
  278. self, model_config: ModelConfig, device_config: DeviceConfig,
  279. lora_config: Optional[LoRAConfig],
  280. vision_language_config: Optional[VisionLanguageConfig]
  281. ) -> nn.Module:
  282. """Load a serialized model with tensorizer.
  283. See the examples/tensorize_aphrodite_model.py example "
  284. script for serializing Aphrodite models."""
  285. with set_default_torch_dtype(model_config.dtype):
  286. with torch.device(device_config.device):
  287. model_class = get_model_architecture(model_config)[0]
  288. linear_method = _get_linear_method(model_config,
  289. self.load_config)
  290. extra_kwargs = _get_model_initialization_kwargs(
  291. model_class, lora_config, vision_language_config)
  292. extra_kwargs["linear_method"] = linear_method
  293. tensorizer_config = copy.copy(self.tensorizer_config)
  294. tensorizer_config.model_class = model_class
  295. tensorizer_config.hf_config = model_config.hf_config
  296. tensorizer_config.dtype = model_config.dtype
  297. model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
  298. return model.eval()
  299. def load_model(self, *, model_config: ModelConfig,
  300. device_config: DeviceConfig,
  301. lora_config: Optional[LoRAConfig],
  302. vision_language_config: Optional[VisionLanguageConfig],
  303. parallel_config: ParallelConfig,
  304. scheduler_config: SchedulerConfig) -> nn.Module:
  305. self._verify_config(model_config, parallel_config)
  306. if is_aphrodite_serialized_tensorizer(self.tensorizer_config):
  307. return self._load_model_serialized(model_config, device_config,
  308. lora_config,
  309. vision_language_config)
  310. return self._load_model_unserialized(model_config, device_config,
  311. lora_config,
  312. vision_language_config)
  313. def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
  314. """Get a model loader based on the load format."""
  315. if isinstance(load_config.load_format, type):
  316. return load_config.load_format(load_config)
  317. if load_config.load_format == LoadFormat.DUMMY:
  318. return DummyModelLoader(load_config)
  319. if load_config.load_format == LoadFormat.TENSORIZER:
  320. return TensorizerLoader(load_config)
  321. return DefaultModelLoader(load_config)