# tensorizer.py
import argparse
import dataclasses
import io
import os
import re
import time
from dataclasses import dataclass
from functools import partial
from typing import BinaryIO, Generator, Optional, Tuple, Type, Union

import torch
from loguru import logger
from torch import nn
from transformers import PretrainedConfig

import aphrodite.common.envs as envs
from aphrodite.common.config import ModelConfig, ParallelConfig
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.quantization.base_config import QuantizationConfig
  21. tensorizer_error_msg = None
  22. try:
  23. from tensorizer import (DecryptionParams, EncryptionParams,
  24. TensorDeserializer, TensorSerializer)
  25. from tensorizer.stream_io import open_stream
  26. from tensorizer.utils import (convert_bytes, get_mem_usage,
  27. no_init_or_tensor)
  28. _read_stream, _write_stream = (partial(
  29. open_stream,
  30. mode=mode,
  31. ) for mode in ("rb", "wb+"))
  32. except ImportError as e:
  33. tensorizer_error_msg = e
  34. __all__ = [
  35. 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
  36. 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
  37. 'no_init_or_tensor', 'TensorizerConfig'
  38. ]
  39. @dataclass
  40. class TensorizerConfig:
  41. tensorizer_uri: str
  42. aphrodite_tensorized: Optional[bool] = False
  43. verify_hash: Optional[bool] = False
  44. num_readers: Optional[int] = None
  45. encryption_keyfile: Optional[str] = None
  46. s3_access_key_id: Optional[str] = None
  47. s3_secret_access_key: Optional[str] = None
  48. s3_endpoint: Optional[str] = None
  49. model_class: Optional[Type[torch.nn.Module]] = None
  50. hf_config: Optional[PretrainedConfig] = None
  51. dtype: Optional[Union[str, torch.dtype]] = None
  52. _is_sharded: bool = False
  53. def __post_init__(self):
  54. # check if the configuration is for a sharded Aphrodite model
  55. self._is_sharded = isinstance(self.tensorizer_uri, str) \
  56. and re.search(r'%0\dd', self.tensorizer_uri) is not None
  57. def _construct_tensorizer_args(self) -> "TensorizerArgs":
  58. tensorizer_args = {
  59. "tensorizer_uri": self.tensorizer_uri,
  60. "aphrodite_tensorized": self.aphrodite_tensorized,
  61. "verify_hash": self.verify_hash,
  62. "num_readers": self.num_readers,
  63. "encryption_keyfile": self.encryption_keyfile,
  64. "s3_access_key_id": self.s3_access_key_id,
  65. "s3_secret_access_key": self.s3_secret_access_key,
  66. "s3_endpoint": self.s3_endpoint,
  67. }
  68. return TensorizerArgs(**tensorizer_args)
  69. def verify_with_parallel_config(
  70. self,
  71. parallel_config: "ParallelConfig",
  72. ) -> None:
  73. if parallel_config.tensor_parallel_size > 1 \
  74. and not self._is_sharded:
  75. raise ValueError(
  76. "For a sharded model, tensorizer_uri should include a"
  77. " string format template like '%04d' to be formatted"
  78. " with the rank of the shard")
  79. def verify_with_model_config(self, model_config: "ModelConfig") -> None:
  80. if (model_config.quantization is not None
  81. and self.tensorizer_uri is not None):
  82. logger.warning(
  83. "Loading a model using Tensorizer with quantization on "
  84. "aphrodite is unstable and may lead to errors.")
  85. def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
  86. if tensorizer_args is None:
  87. tensorizer_args = self._construct_tensorizer_args()
  88. return open_stream(self.tensorizer_uri,
  89. **tensorizer_args.stream_params)
  90. def load_with_tensorizer(tensorizer_config: TensorizerConfig,
  91. **extra_kwargs) -> nn.Module:
  92. tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
  93. return tensorizer.deserialize()
  94. @dataclass
  95. class TensorizerArgs:
  96. tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
  97. bytes, os.PathLike, int]
  98. aphrodite_tensorized: Optional[bool] = False
  99. verify_hash: Optional[bool] = False
  100. num_readers: Optional[int] = None
  101. encryption_keyfile: Optional[str] = None
  102. s3_access_key_id: Optional[str] = None
  103. s3_secret_access_key: Optional[str] = None
  104. s3_endpoint: Optional[str] = None
  105. """
  106. Args for the TensorizerAgent class. These are used to configure the behavior
  107. of the TensorDeserializer when loading tensors from a serialized model.
  108. Args:
  109. tensorizer_uri: Path to serialized model tensors. Can be a local file
  110. path or a S3 URI.
  111. aphrodite_tensorized: If True, indicates that the serialized model is a
  112. aphrodite model. This is used to determine the behavior of the
  113. TensorDeserializer when loading tensors from a serialized model.
  114. It is far faster to deserialize a aphrodite model as it utilizes
  115. ttensorizer's optimized GPU loading. Note that this is now
  116. deprecated, as serialized Aphrodite models are now automatically
  117. inferred as Aphrodite models.
  118. verify_hash: If True, the hashes of each tensor will be verified against
  119. the hashes stored in the metadata. A `HashMismatchError` will be
  120. raised if any of the hashes do not match.
  121. num_readers: Controls how many threads are allowed to read concurrently
  122. from the source file. Default is `None`, which will dynamically set
  123. the number of readers based on the number of available
  124. resources and model size. This greatly increases performance.
  125. encryption_keyfile: File path to a binary file containing a
  126. binary key to use for decryption. `None` (the default) means
  127. no decryption. See the example script in
  128. examples/tensorize_aphrodite_model.py.
  129. s3_access_key_id: The access key for the S3 bucket. Can also be set via
  130. the S3_ACCESS_KEY_ID environment variable.
  131. s3_secret_access_key: The secret access key for the S3 bucket. Can also
  132. be set via the S3_SECRET_ACCESS_KEY environment variable.
  133. s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
  134. S3_ENDPOINT_URL environment variable.
  135. """
  136. def __post_init__(self):
  137. self.file_obj = self.tensorizer_uri
  138. self.s3_access_key_id = (self.s3_access_key_id
  139. or envs.S3_ACCESS_KEY_ID) or None
  140. self.s3_secret_access_key = (
  141. self.s3_secret_access_key
  142. or envs.S3_SECRET_ACCESS_KEY) or None
  143. self.s3_endpoint = (self.s3_endpoint
  144. or envs.S3_ENDPOINT_URL) or None
  145. self.stream_params = {
  146. "s3_access_key_id": self.s3_access_key_id,
  147. "s3_secret_access_key": self.s3_secret_access_key,
  148. "s3_endpoint": self.s3_endpoint,
  149. }
  150. self.deserializer_params = {
  151. "verify_hash": self.verify_hash,
  152. "encryption": self.encryption_keyfile,
  153. "num_readers": self.num_readers
  154. }
  155. if self.encryption_keyfile:
  156. with open_stream(
  157. self.encryption_keyfile,
  158. **self.stream_params,
  159. ) as stream:
  160. key = stream.read()
  161. decryption_params = DecryptionParams.from_key(key)
  162. self.deserializer_params['encryption'] = decryption_params
  163. @staticmethod
  164. def add_cli_args(
  165. parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
  166. """Tensorizer CLI arguments"""
  167. # Tensorizer options arg group
  168. group = parser.add_argument_group(
  169. 'tensorizer options',
  170. description=('Options for configuring the behavior of the'
  171. ' tensorizer deserializer when '
  172. 'load_format=tensorizer is specified when '
  173. 'initializing an AphroditeEngine, either via the CLI '
  174. 'when running the Aphrodite OpenAI inference server '
  175. 'with a JSON string passed to '
  176. '--model-loader-extra-config or as arguments given '
  177. 'to TensorizerConfig when passed to '
  178. 'model_loader_extra_config in the constructor '
  179. 'for AphroditeEngine.'))
  180. group.add_argument(
  181. "--tensorizer-uri",
  182. help="Path to serialized model tensors. Can be a local file path,"
  183. " or an HTTP(S) or S3 URI.",
  184. )
  185. group.add_argument(
  186. "--verify-hash",
  187. action="store_true",
  188. help="If enabled, the hashes of each tensor will be verified"
  189. " against the hashes stored in the file metadata. An exception"
  190. " will be raised if any of the hashes do not match.",
  191. )
  192. group.add_argument(
  193. "--encryption-keyfile",
  194. default=None,
  195. help="The file path to a binary file containing a binary key to "
  196. "use for decryption. Can be a file path or S3 network URI.")
  197. group.add_argument(
  198. "--num-readers",
  199. default=None,
  200. type=int,
  201. help="Controls how many threads are allowed to read concurrently "
  202. "from the source file. Default is `None`, which will dynamically "
  203. "set the number of readers based on the available resources "
  204. "and model size. This greatly increases performance.")
  205. group.add_argument(
  206. "--s3-access-key-id",
  207. default=None,
  208. help="The access key for the S3 bucket. Can also be set via the "
  209. "S3_ACCESS_KEY_ID environment variable.",
  210. )
  211. group.add_argument(
  212. "--s3-secret-access-key",
  213. default=None,
  214. help="The secret access key for the S3 bucket. Can also be set via "
  215. "the S3_SECRET_ACCESS_KEY environment variable.",
  216. )
  217. group.add_argument(
  218. "--s3-endpoint",
  219. default=None,
  220. help="The endpoint for the S3 bucket. Can also be set via the "
  221. "S3_ENDPOINT_URL environment variable.",
  222. )
  223. return parser
  224. @classmethod
  225. def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
  226. attrs = [attr.name for attr in dataclasses.fields(cls)]
  227. tensorizer_args = cls(**{
  228. attr: getattr(args, attr)
  229. for attr in attrs if hasattr(args, attr)
  230. })
  231. return tensorizer_args
  232. class TensorizerAgent:
  233. """
  234. A class for performing tensorizer deserializations specifically for
  235. aphrodite models using plaid_mode. Uses TensorizerArgs to configure the
  236. behavior of the TensorDeserializer when loading tensors from a serialized
  237. model. For deserializations of HuggingFace models, TensorDeserializer is
  238. instead used as an iterator directly in the func hf_model_weights_iterator
  239. in aphrodite/modeling/model_loader/weight_utils.py
  240. """
  241. def __init__(self, tensorizer_config: TensorizerConfig,
  242. quant_config: QuantizationConfig, **extra_kwargs):
  243. if tensorizer_error_msg is not None:
  244. raise ImportError(
  245. "Tensorizer is not installed. Please install tensorizer "
  246. "to use this feature with "
  247. "`pip install aphrodite-engine[tensorizer]`. "
  248. "Error message: {}".format(tensorizer_error_msg))
  249. self.tensorizer_config = tensorizer_config
  250. self.tensorizer_args = (
  251. self.tensorizer_config._construct_tensorizer_args())
  252. self.extra_kwargs = extra_kwargs
  253. if extra_kwargs.get("quant_config", None) is not None:
  254. self.quant_config = extra_kwargs["quant_config"]
  255. else:
  256. self.quant_config = quant_config
  257. self.model = self._init_model()
  258. def _init_model(self):
  259. model_args = self.tensorizer_config.hf_config
  260. model_args.torch_dtype = self.tensorizer_config.dtype
  261. with no_init_or_tensor():
  262. return self.tensorizer_config.model_class(
  263. config=model_args,
  264. quant_config=self.quant_config,
  265. **self.extra_kwargs)
  266. def _resize_lora_embeddings(self):
  267. """Modify LoRA embedding layers to use bigger tensors
  268. to allow for adapter added tokens."""
  269. for child in self.model.modules():
  270. if (isinstance(child, VocabParallelEmbedding)
  271. and child.weight.shape[0] <
  272. child.num_embeddings_per_partition):
  273. new_weight = torch.empty(child.num_embeddings_per_partition,
  274. child.embedding_dim,
  275. dtype=child.weight.dtype,
  276. device=child.weight.device)
  277. new_weight[:child.weight.shape[0]].copy_(child.weight.data)
  278. new_weight[child.weight.shape[0]:].fill_(0)
  279. child.weight.data = new_weight
  280. def _check_tensors_on_meta_device(self):
  281. for tensor in self.model.state_dict().values():
  282. if tensor.device.type == 'meta':
  283. raise ValueError(
  284. "The serialized model contains tensors on the meta device,"
  285. " indicating that some tensors were not loaded properly."
  286. " Please check that the parameters of the model being"
  287. " specified match that of the serialized model, such as"
  288. " its quantization.")
  289. def deserialize(self):
  290. """
  291. Deserialize the model using the TensorDeserializer. This method is
  292. specifically for Aphrodite models using tensorizer's plaid_mode.
  293. The deserializer makes use of tensorizer_args.stream_params
  294. to configure the behavior of the stream when loading tensors from a
  295. serialized model. The deserializer_params are used to configure the
  296. behavior of the TensorDeserializer when loading tensors themselves.
  297. Documentation on these params can be found in TensorizerArgs
  298. Returns:
  299. nn.Module: The deserialized model.
  300. """
  301. before_mem = get_mem_usage()
  302. start = time.perf_counter()
  303. with _read_stream(
  304. self.tensorizer_config.tensorizer_uri,
  305. **self.tensorizer_args.stream_params
  306. ) as stream, TensorDeserializer(
  307. stream,
  308. dtype=self.tensorizer_config.dtype,
  309. device=f'cuda:{torch.cuda.current_device()}',
  310. **self.tensorizer_args.deserializer_params) as deserializer:
  311. deserializer.load_into_module(self.model)
  312. end = time.perf_counter()
  313. total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
  314. duration = end - start
  315. per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
  316. after_mem = get_mem_usage()
  317. deserializer.close()
  318. logger.info(f"Deserialized {total_bytes_str} in "
  319. f"{end - start:0.2f}s, {per_second}/s")
  320. logger.info(f"Memory usage before: {before_mem}")
  321. logger.info(f"Memory usage after: {after_mem}")
  322. self._check_tensors_on_meta_device()
  323. self._resize_lora_embeddings()
  324. del self.model.aphrodite_tensorized_marker
  325. return self.model.eval()
  326. def tensorizer_weights_iterator(
  327. tensorizer_args: "TensorizerArgs"
  328. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  329. logger.warning(
  330. "Deserializing HuggingFace models is not optimized for "
  331. "loading on Aphrodite, as tensorizer is forced to load to CPU. "
  332. "Consider deserializing a Aphrodite model instead for faster "
  333. "load times. See the examples/tensorize_aphrodite_model.py example "
  334. "script for serializing Aphrodite models.")
  335. deserializer_args = tensorizer_args.deserializer_params
  336. stream_params = tensorizer_args.stream_params
  337. stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
  338. with TensorDeserializer(stream, **deserializer_args,
  339. device="cpu") as state:
  340. for name, param in state.items():
  341. yield name, param
  342. del state
  343. def is_aphrodite_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
  344. """
  345. Infer if the model is a Aphrodite model by checking the weights for
  346. a Aphrodite tensorized marker.
  347. Args:
  348. tensorizer_config: The TensorizerConfig object containing the
  349. tensorizer_uri to the serialized model.
  350. Returns:
  351. bool: True if the model is a Aphrodite model, False otherwise.
  352. """
  353. tensorizer_args = tensorizer_config._construct_tensorizer_args()
  354. deserializer = TensorDeserializer(open_stream(
  355. tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params),
  356. **tensorizer_args.deserializer_params,
  357. lazy_load=True)
  358. if tensorizer_config.aphrodite_tensorized:
  359. logger.warning(
  360. "Please note that newly serialized Aphrodite models are "
  361. "automatically inferred as Aphrodite models, so setting "
  362. "aphrodite_tensorized=True is only necessary for models serialized "
  363. "prior to this change.")
  364. return True
  365. if (".aphrodite_tensorized_marker" in deserializer):
  366. return True
  367. return False
  368. def serialize_aphrodite_model(
  369. model: nn.Module,
  370. tensorizer_config: TensorizerConfig,
  371. ) -> nn.Module:
  372. model.register_parameter(
  373. "aphrodite_tensorized_marker",
  374. nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
  375. tensorizer_args = tensorizer_config._construct_tensorizer_args()
  376. encryption_params = None
  377. if (keyfile := tensorizer_config.encryption_keyfile) is not None:
  378. with open(keyfile, "rb") as f:
  379. key = f.read()
  380. encryption_params = EncryptionParams(key=key)
  381. output_file = tensorizer_args.tensorizer_uri
  382. if tensorizer_config._is_sharded:
  383. from aphrodite.distributed import get_tensor_model_parallel_rank
  384. output_file = output_file % get_tensor_model_parallel_rank()
  385. with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
  386. serializer = TensorSerializer(stream, encryption=encryption_params)
  387. serializer.write_module(model)
  388. serializer.close()
  389. logger.info(f"Successfully serialized model to {str(output_file)}")
  390. return model
  391. def tensorize_aphrodite_model(engine_args: EngineArgs,
  392. tensorizer_config: TensorizerConfig,
  393. generate_keyfile: bool = True):
  394. """Utility to load a model and then serialize it with Tensorizer
  395. Intended to be used separately from running a aphrodite server since it
  396. creates its own Engine instance.
  397. """
  398. engine_config = engine_args.create_engine_config()
  399. tensorizer_config.verify_with_model_config(engine_config.model_config)
  400. tensorizer_config.verify_with_parallel_config(
  401. engine_config.parallel_config)
  402. # generate the encryption key before creating the engine to support
  403. # sharding
  404. if generate_keyfile and (keyfile :=
  405. tensorizer_config.encryption_keyfile) is not None:
  406. encryption_params = EncryptionParams.random()
  407. with _write_stream(
  408. keyfile,
  409. s3_access_key_id=tensorizer_config.s3_access_key_id,
  410. s3_secret_access_key=tensorizer_config.s3_secret_access_key,
  411. s3_endpoint=tensorizer_config.s3_endpoint,
  412. ) as stream:
  413. stream.write(encryption_params.key)
  414. engine = AphroditeEngine.from_engine_args(engine_args)
  415. if tensorizer_config._is_sharded:
  416. # if the engine is a distributed engine (for tensor parallel) then each
  417. # worker shard needs to serialize its part of the model.
  418. engine.model_executor._run_workers(
  419. "save_tensorized_model",
  420. tensorizer_config=tensorizer_config,
  421. )
  422. else:
  423. # with a single worker, we can get to the underlying model directly
  424. serialize_aphrodite_model(
  425. engine.model_executor.driver_worker.model_runner.model,
  426. tensorizer_config,
  427. )