# tensorizer.py
import argparse
import dataclasses
import io
import os
import re
import time
from dataclasses import dataclass
from functools import partial
from typing import BinaryIO, Generator, Optional, Tuple, Type, Union

import torch
from loguru import logger
from torch import nn
from transformers import PretrainedConfig

from aphrodite.common.config import ModelConfig, ParallelConfig
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.quantization.base_config import QuantizationConfig

tensorizer_error_msg = None

try:
    from tensorizer import (DecryptionParams, EncryptionParams,
                            TensorDeserializer, TensorSerializer)
    from tensorizer.stream_io import open_stream
    from tensorizer.utils import (convert_bytes, get_mem_usage,
                                  no_init_or_tensor)

    _read_stream, _write_stream = (partial(
        open_stream,
        mode=mode,
    ) for mode in ("rb", "wb+"))
except ImportError as e:
    tensorizer_error_msg = e

__all__ = [
    'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
    'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
    'no_init_or_tensor', 'TensorizerConfig'
]


@dataclass
class TensorizerConfig:
    tensorizer_uri: str
    aphrodite_tensorized: Optional[bool] = False
    verify_hash: Optional[bool] = False
    num_readers: Optional[int] = None
    encryption_keyfile: Optional[str] = None
    s3_access_key_id: Optional[str] = None
    s3_secret_access_key: Optional[str] = None
    s3_endpoint: Optional[str] = None
    model_class: Optional[Type[torch.nn.Module]] = None
    hf_config: Optional[PretrainedConfig] = None
    dtype: Optional[Union[str, torch.dtype]] = None
    _is_sharded: bool = False

    def __post_init__(self):
        # check if the configuration is for a sharded Aphrodite model
        self._is_sharded = isinstance(self.tensorizer_uri, str) \
            and re.search(r'%0\dd', self.tensorizer_uri) is not None

    def _construct_tensorizer_args(self) -> "TensorizerArgs":
        tensorizer_args = {
            "tensorizer_uri": self.tensorizer_uri,
            "aphrodite_tensorized": self.aphrodite_tensorized,
            "verify_hash": self.verify_hash,
            "num_readers": self.num_readers,
            "encryption_keyfile": self.encryption_keyfile,
            "s3_access_key_id": self.s3_access_key_id,
            "s3_secret_access_key": self.s3_secret_access_key,
            "s3_endpoint": self.s3_endpoint,
        }
        return TensorizerArgs(**tensorizer_args)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        if parallel_config.tensor_parallel_size > 1 \
                and not self._is_sharded:
            raise ValueError(
                "For a sharded model, tensorizer_uri should include a"
                " string format template like '%04d' to be formatted"
                " with the rank of the shard")

    def verify_with_model_config(self, model_config: "ModelConfig") -> None:
        if (model_config.quantization is not None
                and self.tensorizer_uri is not None):
            logger.warning(
                "Loading a model using Tensorizer with quantization on "
                "aphrodite is unstable and may lead to errors.")


def load_with_tensorizer(tensorizer_config: TensorizerConfig,
                         **extra_kwargs) -> nn.Module:
    tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
    return tensorizer.deserialize()
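
# Illustrative usage sketch (comments only, not executed): a TensorizerConfig
# is normally handed to the engine through `model_loader_extra_config` with
# `load_format=tensorizer`, as described in TensorizerArgs.add_cli_args
# below. The URI is a hypothetical placeholder, and the EngineArgs keyword
# names are assumed rather than guaranteed by this module:
#
#   config = TensorizerConfig(
#       tensorizer_uri="s3://my-bucket/model.tensors",  # hypothetical URI
#       num_readers=None,           # let tensorizer pick the reader count
#       encryption_keyfile=None,    # or a path/S3 URI to a binary key
#   )
#   # engine_args = EngineArgs(..., load_format="tensorizer",
#   #                          model_loader_extra_config=config)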


@dataclass
class TensorizerArgs:
    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
                          bytes, os.PathLike, int]
    aphrodite_tensorized: Optional[bool] = False
    verify_hash: Optional[bool] = False
    num_readers: Optional[int] = None
    encryption_keyfile: Optional[str] = None
    s3_access_key_id: Optional[str] = None
    s3_secret_access_key: Optional[str] = None
    s3_endpoint: Optional[str] = None
    """
    Args for the TensorizerAgent class. These are used to configure the
    behavior of the TensorDeserializer when loading tensors from a serialized
    model.

    Args:
        tensorizer_uri: Path to serialized model tensors. Can be a local file
            path or an S3 URI.
        aphrodite_tensorized: If True, indicates that the serialized model is
            an Aphrodite model. This is used to determine the behavior of the
            TensorDeserializer when loading tensors from a serialized model.
            It is far faster to deserialize an Aphrodite model as it utilizes
            tensorizer's optimized GPU loading. Note that this is now
            deprecated, as serialized Aphrodite models are now automatically
            inferred as Aphrodite models.
        verify_hash: If True, the hashes of each tensor will be verified
            against the hashes stored in the metadata. A `HashMismatchError`
            will be raised if any of the hashes do not match.
        num_readers: Controls how many threads are allowed to read
            concurrently from the source file. Default is `None`, which will
            dynamically set the number of readers based on the number of
            available resources and model size. This greatly increases
            performance.
        encryption_keyfile: File path to a binary file containing a
            binary key to use for decryption. `None` (the default) means
            no decryption. See the example script in
            examples/tensorize_aphrodite_model.py.
        s3_access_key_id: The access key for the S3 bucket. Can also be set
            via the S3_ACCESS_KEY_ID environment variable.
        s3_secret_access_key: The secret access key for the S3 bucket. Can
            also be set via the S3_SECRET_ACCESS_KEY environment variable.
        s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
            S3_ENDPOINT_URL environment variable.
    """

    def __post_init__(self):
        self.file_obj = self.tensorizer_uri
        self.s3_access_key_id = (self.s3_access_key_id
                                 or os.environ.get("S3_ACCESS_KEY_ID")) or None
        self.s3_secret_access_key = (
            self.s3_secret_access_key
            or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
        self.s3_endpoint = (self.s3_endpoint
                            or os.environ.get("S3_ENDPOINT_URL")) or None
        self.stream_params = {
            "s3_access_key_id": self.s3_access_key_id,
            "s3_secret_access_key": self.s3_secret_access_key,
            "s3_endpoint": self.s3_endpoint,
        }

        self.deserializer_params = {
            "verify_hash": self.verify_hash,
            "encryption": self.encryption_keyfile,
            "num_readers": self.num_readers
        }
        if self.encryption_keyfile:
            with open_stream(
                    self.encryption_keyfile,
                    **self.stream_params,
            ) as stream:
                key = stream.read()
                decryption_params = DecryptionParams.from_key(key)
                self.deserializer_params['encryption'] = decryption_params

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Tensorizer CLI arguments"""

        # Tensorizer options arg group
        group = parser.add_argument_group(
            'tensorizer options',
            description=('Options for configuring the behavior of the'
                         ' tensorizer deserializer when '
                         'load_format=tensorizer is specified when '
                         'initializing an AphroditeEngine, either via the CLI '
                         'when running the Aphrodite OpenAI inference server '
                         'with a JSON string passed to '
                         '--model-loader-extra-config or as arguments given '
                         'to TensorizerConfig when passed to '
                         'model_loader_extra_config in the constructor '
                         'for AphroditeEngine.'))

        group.add_argument(
            "--tensorizer-uri",
            help="Path to serialized model tensors. Can be a local file path,"
            " or an HTTP(S) or S3 URI.",
        )
        group.add_argument(
            "--verify-hash",
            action="store_true",
            help="If enabled, the hashes of each tensor will be verified"
            " against the hashes stored in the file metadata. An exception"
            " will be raised if any of the hashes do not match.",
        )
        group.add_argument(
            "--encryption-keyfile",
            default=None,
            help="The file path to a binary file containing a binary key to "
            "use for decryption. Can be a file path or S3 network URI.")
        group.add_argument(
            "--num-readers",
            default=None,
            type=int,
            help="Controls how many threads are allowed to read concurrently "
            "from the source file. Default is `None`, which will dynamically "
            "set the number of readers based on the available resources "
            "and model size. This greatly increases performance.")
        group.add_argument(
            "--s3-access-key-id",
            default=None,
            help="The access key for the S3 bucket. Can also be set via the "
            "S3_ACCESS_KEY_ID environment variable.",
        )
        group.add_argument(
            "--s3-secret-access-key",
            default=None,
            help="The secret access key for the S3 bucket. Can also be set "
            "via the S3_SECRET_ACCESS_KEY environment variable.",
        )
        group.add_argument(
            "--s3-endpoint",
            default=None,
            help="The endpoint for the S3 bucket. Can also be set via the "
            "S3_ENDPOINT_URL environment variable.",
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        tensorizer_args = cls(**{
            attr: getattr(args, attr)
            for attr in attrs if hasattr(args, attr)
        })
        return tensorizer_args
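
# Illustrative sketch (comments only) of wiring the helpers above into a
# standalone argparse parser, in the spirit of the
# examples/tensorize_aphrodite_model.py script referenced in this module;
# the parser description is arbitrary:
#
#   parser = argparse.ArgumentParser(description="Tensorize a model")
#   parser = TensorizerArgs.add_cli_args(parser)
#   args = parser.parse_args()
#   tensorizer_args = TensorizerArgs.from_cli_args(args)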


class TensorizerAgent:
    """
    A class for performing tensorizer deserializations specifically for
    aphrodite models using plaid_mode. Uses TensorizerArgs to configure the
    behavior of the TensorDeserializer when loading tensors from a serialized
    model. For deserializations of HuggingFace models, TensorDeserializer is
    instead used as an iterator directly in the function
    hf_model_weights_iterator in
    aphrodite/modeling/model_loader/weight_utils.py.
    """

    def __init__(self, tensorizer_config: TensorizerConfig,
                 quant_config: QuantizationConfig, **extra_kwargs):
        if tensorizer_error_msg is not None:
            raise ImportError(
                "Tensorizer is not installed. Please install tensorizer "
                "to use this feature with "
                "`pip install aphrodite-engine[tensorizer]`. "
                "Error message: {}".format(tensorizer_error_msg))

        self.tensorizer_config = tensorizer_config
        self.tensorizer_args = (
            self.tensorizer_config._construct_tensorizer_args())
        self.extra_kwargs = extra_kwargs
        if extra_kwargs.get("quant_config", None) is not None:
            self.quant_config = extra_kwargs["quant_config"]
        else:
            self.quant_config = quant_config
        self.model = self._init_model()

    def _init_model(self):
        model_args = self.tensorizer_config.hf_config
        model_args.torch_dtype = self.tensorizer_config.dtype
        with no_init_or_tensor():
            return self.tensorizer_config.model_class(
                config=model_args,
                quant_config=self.quant_config,
                **self.extra_kwargs)

    def _resize_lora_embeddings(self):
        """Modify LoRA embedding layers to use bigger tensors
        to allow for adapter added tokens."""
        for child in self.model.modules():
            if (isinstance(child, VocabParallelEmbedding)
                    and child.weight.shape[0] <
                    child.num_embeddings_per_partition):
                new_weight = torch.empty(child.num_embeddings_per_partition,
                                         child.embedding_dim,
                                         dtype=child.weight.dtype,
                                         device=child.weight.device)
                new_weight[:child.weight.shape[0]].copy_(child.weight.data)
                new_weight[child.weight.shape[0]:].fill_(0)
                child.weight.data = new_weight

    def _check_tensors_on_meta_device(self):
        for tensor in self.model.state_dict().values():
            if tensor.device.type == 'meta':
                raise ValueError(
                    "The serialized model contains tensors on the meta device,"
                    " indicating that some tensors were not loaded properly."
                    " Please check that the parameters of the model being"
                    " specified match that of the serialized model, such as"
                    " its quantization.")

    def deserialize(self):
        """
        Deserialize the model using the TensorDeserializer. This method is
        specifically for Aphrodite models using tensorizer's plaid_mode.

        The deserializer makes use of tensorizer_args.stream_params
        to configure the behavior of the stream when loading tensors from a
        serialized model. The deserializer_params are used to configure the
        behavior of the TensorDeserializer when loading tensors themselves.
        Documentation on these params can be found in TensorizerArgs.

        Returns:
            nn.Module: The deserialized model.
        """
        before_mem = get_mem_usage()
        start = time.perf_counter()
        with _read_stream(
                self.tensorizer_config.tensorizer_uri,
                **self.tensorizer_args.stream_params
        ) as stream, TensorDeserializer(
                stream,
                dtype=self.tensorizer_config.dtype,
                device=f'cuda:{torch.cuda.current_device()}',
                **self.tensorizer_args.deserializer_params) as deserializer:
            deserializer.load_into_module(self.model)
            end = time.perf_counter()

        total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
        duration = end - start
        per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
        after_mem = get_mem_usage()
        deserializer.close()
        logger.info(f"Deserialized {total_bytes_str} in "
                    f"{end - start:0.2f}s, {per_second}/s")
        logger.info(f"Memory usage before: {before_mem}")
        logger.info(f"Memory usage after: {after_mem}")

        self._check_tensors_on_meta_device()
        self._resize_lora_embeddings()
        del self.model.aphrodite_tensorized_marker
        return self.model.eval()


def tensorizer_weights_iterator(
    tensorizer_args: "TensorizerArgs"
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    logger.warning(
        "Deserializing HuggingFace models is not optimized for "
        "loading on Aphrodite, as tensorizer is forced to load to CPU. "
        "Consider deserializing an Aphrodite model instead for faster "
        "load times. See the examples/tensorize_aphrodite_model.py example "
        "script for serializing Aphrodite models.")

    deserializer_args = tensorizer_args.deserializer_params
    stream_params = tensorizer_args.stream_params
    stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
    with TensorDeserializer(stream, **deserializer_args,
                            device="cpu") as state:
        for name, param in state.items():
            yield name, param
    del state
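
# Illustrative sketch (comments only): iterating over the weights of a
# HuggingFace-style serialized model on CPU. The local path is a
# hypothetical placeholder:
#
#   args = TensorizerConfig(
#       tensorizer_uri="model.tensors")._construct_tensorizer_args()
#   for name, tensor in tensorizer_weights_iterator(args):
#       print(name, tuple(tensor.shape))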


def is_aphrodite_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
    """
    Infer if the model is an Aphrodite model by checking the weights for
    an Aphrodite tensorized marker.

    Args:
        tensorizer_config: The TensorizerConfig object containing the
            tensorizer_uri to the serialized model.

    Returns:
        bool: True if the model is an Aphrodite model, False otherwise.
    """
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    deserializer = TensorDeserializer(open_stream(
        tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params),
                                      **tensorizer_args.deserializer_params,
                                      lazy_load=True)
    if tensorizer_config.aphrodite_tensorized:
        logger.warning(
            "Please note that newly serialized Aphrodite models are "
            "automatically inferred as Aphrodite models, so setting "
            "aphrodite_tensorized=True is only necessary for models "
            "serialized prior to this change.")
        return True
    if (".aphrodite_tensorized_marker" in deserializer):
        return True
    return False


def serialize_aphrodite_model(
    model: nn.Module,
    tensorizer_config: TensorizerConfig,
) -> nn.Module:
    model.register_parameter(
        "aphrodite_tensorized_marker",
        nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    encryption_params = None
    if (keyfile := tensorizer_config.encryption_keyfile) is not None:
        with open(keyfile, "rb") as f:
            key = f.read()
        encryption_params = EncryptionParams(key=key)

    output_file = tensorizer_args.tensorizer_uri
    if tensorizer_config._is_sharded:
        from aphrodite.distributed import get_tensor_model_parallel_rank
        output_file = output_file % get_tensor_model_parallel_rank()

    with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()
    logger.info(f"Successfully serialized model to {str(output_file)}")
    return model
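
# Illustrative sketch (comments only): serializing an already-loaded
# Aphrodite nn.Module to a local path. Both `model` and the output URI are
# hypothetical; for the end-to-end path that also builds the engine, see
# tensorize_aphrodite_model below:
#
#   out_config = TensorizerConfig(tensorizer_uri="/tmp/model.tensors")
#   serialize_aphrodite_model(model, out_config)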


def tensorize_aphrodite_model(engine_args: EngineArgs,
                              tensorizer_config: TensorizerConfig,
                              generate_keyfile: bool = True):
    """Utility to load a model and then serialize it with Tensorizer

    Intended to be used separately from running an Aphrodite server since it
    creates its own Engine instance.
    """
    engine_config = engine_args.create_engine_config()
    tensorizer_config.verify_with_model_config(engine_config.model_config)
    tensorizer_config.verify_with_parallel_config(
        engine_config.parallel_config)

    # generate the encryption key before creating the engine to support
    # sharding
    if generate_keyfile and (keyfile :=
                             tensorizer_config.encryption_keyfile) is not None:
        encryption_params = EncryptionParams.random()
        with _write_stream(
                keyfile,
                s3_access_key_id=tensorizer_config.s3_access_key_id,
                s3_secret_access_key=tensorizer_config.s3_secret_access_key,
                s3_endpoint=tensorizer_config.s3_endpoint,
        ) as stream:
            stream.write(encryption_params.key)

    engine = AphroditeEngine.from_engine_args(engine_args)
    if tensorizer_config._is_sharded:
        # if the engine is a distributed engine (for tensor parallel), then
        # each worker shard needs to serialize its part of the model
        engine.model_executor._run_workers(
            "save_tensorized_model",
            tensorizer_config=tensorizer_config,
        )
    else:
        # with a single worker, we can get to the underlying model directly
        serialize_aphrodite_model(
            engine.model_executor.driver_worker.model_runner.model,
            tensorizer_config,
        )
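
# Illustrative end-to-end sketch (comments only), mirroring what the
# examples/tensorize_aphrodite_model.py script referenced above is described
# as doing. The model name and output URI are hypothetical placeholders, and
# EngineArgs(model=...) is assumed, not guaranteed, to accept a HuggingFace
# model id:
#
#   engine_args = EngineArgs(model="facebook/opt-125m")
#   out_config = TensorizerConfig(
#       tensorizer_uri="s3://my-bucket/opt-125m.tensors",
#       encryption_keyfile=None)
#   tensorize_aphrodite_model(engine_args, out_config)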