# tensorizer.py
  1. import argparse
  2. import dataclasses
  3. import io
  4. import os
  5. import time
  6. import typing
  7. from dataclasses import dataclass
  8. from typing import Generator, Optional, Tuple, Type, Union
  9. import torch
  10. from loguru import logger
  11. from torch import nn
  12. from transformers import PretrainedConfig
  13. from aphrodite.common.config import ModelConfig, ParallelConfig
  14. from aphrodite.modeling.layers.linear import LinearMethodBase
  15. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  16. VocabParallelEmbedding
  17. tensorizer_load_fail = None
  18. try:
  19. from tensorizer import (DecryptionParams, EncryptionParams,
  20. TensorDeserializer, TensorSerializer)
  21. from tensorizer.stream_io import open_stream
  22. from tensorizer.utils import (convert_bytes, get_mem_usage,
  23. no_init_or_tensor)
  24. except ImportError as e:
  25. tensorizer_load_fail = e
  26. __all__ = [
  27. 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
  28. 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
  29. 'no_init_or_tensor', 'TensorizerConfig'
  30. ]
  31. @dataclass
  32. class TensorizerConfig:
  33. tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
  34. str, bytes, os.PathLike, int]
  35. aphrodite_tensorized: bool
  36. verify_hash: Optional[bool] = False
  37. num_readers: Optional[int] = 1
  38. encryption_keyfile: Optional[str] = None
  39. s3_access_key_id: Optional[str] = None
  40. s3_secret_access_key: Optional[str] = None
  41. s3_endpoint: Optional[str] = None
  42. model_class: Optional[Type[torch.nn.Module]] = None
  43. hf_config: Optional[PretrainedConfig] = None
  44. dtype: Optional[Union[str, torch.dtype]] = None
  45. def _construct_tensorizer_args(self) -> "TensorizerArgs":
  46. tensorizer_args = {
  47. "tensorizer_uri": self.tensorizer_uri,
  48. "aphrodite_tensorized": self.aphrodite_tensorized,
  49. "verify_hash": self.verify_hash,
  50. "num_readers": self.num_readers,
  51. "encryption_keyfile": self.encryption_keyfile,
  52. "s3_access_key_id": self.s3_access_key_id,
  53. "s3_secret_access_key": self.s3_secret_access_key,
  54. "s3_endpoint": self.s3_endpoint,
  55. }
  56. return TensorizerArgs(**tensorizer_args)
  57. def verify_with_parallel_config(
  58. self,
  59. parallel_config: "ParallelConfig",
  60. ) -> None:
  61. if (parallel_config.tensor_parallel_size > 1
  62. and self.tensorizer_uri is not None):
  63. raise ValueError(
  64. "Loading to multiple GPUs is not currently supported with "
  65. "aphrodite-serialized models. Please set "
  66. "tensor_parallel_size=1. or use a non-aphrodite-serialized "
  67. "model, such as a serialized Hugging Face `PretrainedModel`.")
  68. def verify_with_model_config(self, model_config: "ModelConfig") -> None:
  69. if (model_config.quantization is not None
  70. and self.tensorizer_uri is not None):
  71. logger.warning(
  72. "Loading a model using Tensorizer with quantization on "
  73. "aphrodite is unstable and may lead to errors.")
  74. def load_with_tensorizer(tensorizer_config: TensorizerConfig,
  75. **extra_kwargs) -> nn.Module:
  76. tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
  77. return tensorizer.deserialize()
  78. def is_aphrodite_serialized_tensorizer(
  79. tensorizer_config: TensorizerConfig) -> bool:
  80. if tensorizer_config is None:
  81. return False
  82. return tensorizer_config.aphrodite_tensorized
  83. @dataclass
  84. class TensorizerArgs:
  85. tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
  86. str, bytes, os.PathLike, int]
  87. aphrodite_tensorized: bool
  88. verify_hash: Optional[bool] = False
  89. num_readers: Optional[int] = 1
  90. encryption_keyfile: Optional[str] = None
  91. s3_access_key_id: Optional[str] = None
  92. s3_secret_access_key: Optional[str] = None
  93. s3_endpoint: Optional[str] = None
  94. """
  95. Args for the TensorizerAgent class. These are used to configure the behavior
  96. of the TensorDeserializer when loading tensors from a serialized model.
  97. Args:
  98. tensorizer_uri: Path to serialized model tensors. Can be a local file
  99. path or a S3 URI.
  100. aphrodite_tensorized: If True, indicates that the serialized model is a
  101. aphrodite model. This is used to determine the behavior of the
  102. TensorDeserializer when loading tensors from a serialized model.
  103. It is far faster to deserialize a aphrodite model as it utilizes
  104. tensorizer's optimized GPU loading.
  105. verify_hash: If True, the hashes of each tensor will be verified against
  106. the hashes stored in the metadata. A `HashMismatchError` will be
  107. raised if any of the hashes do not match.
  108. num_readers: Controls how many threads are allowed to read concurrently
  109. from the source file. Default is 1. This greatly increases
  110. performance.
  111. encryption_keyfile: File path to a binary file containing a
  112. binary key to use for decryption. `None` (the default) means
  113. no decryption. See the example script in
  114. examples/tensorize_aphrodite_model.py.
  115. s3_access_key_id: The access key for the S3 bucket. Can also be set via
  116. the S3_ACCESS_KEY_ID environment variable.
  117. s3_secret_access_key: The secret access key for the S3 bucket. Can also
  118. be set via the S3_SECRET_ACCESS_KEY environment variable.
  119. s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
  120. S3_ENDPOINT_URL environment variable.
  121. """
  122. def __post_init__(self):
  123. self.file_obj = self.tensorizer_uri
  124. self.s3_access_key_id = (self.s3_access_key_id
  125. or os.environ.get("S3_ACCESS_KEY_ID")) or None
  126. self.s3_secret_access_key = (
  127. self.s3_secret_access_key
  128. or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
  129. self.s3_endpoint = (self.s3_endpoint
  130. or os.environ.get("S3_ENDPOINT_URL")) or None
  131. self.stream_params = {
  132. "s3_access_key_id": self.s3_access_key_id,
  133. "s3_secret_access_key": self.s3_secret_access_key,
  134. "s3_endpoint": self.s3_endpoint,
  135. }
  136. self.deserializer_params = {
  137. "verify_hash": self.verify_hash,
  138. "encryption": self.encryption_keyfile,
  139. "num_readers": self.num_readers
  140. }
  141. if self.encryption_keyfile:
  142. with open_stream(
  143. self.encryption_keyfile,
  144. **self.stream_params,
  145. ) as stream:
  146. key = stream.read()
  147. decryption_params = DecryptionParams.from_key(key)
  148. self.deserializer_params['encryption'] = decryption_params
  149. @staticmethod
  150. def add_cli_args(
  151. parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
  152. """Tensorizer CLI arguments"""
  153. # Tensorizer options arg group
  154. group = parser.add_argument_group(
  155. 'tensorizer options',
  156. description=('Options for configuring the behavior of the'
  157. ' tensorizer deserializer when '
  158. '--load-format=tensorizer'))
  159. group.add_argument(
  160. "--tensorizer-uri",
  161. help="Path to serialized model tensors. Can be a local file path,"
  162. " or an HTTP(S) or S3 URI.",
  163. )
  164. group.add_argument(
  165. "--verify-hash",
  166. action="store_true",
  167. help="If enabled, the hashes of each tensor will be verified"
  168. " against the hashes stored in the file metadata. An exception"
  169. " will be raised if any of the hashes do not match.",
  170. )
  171. group.add_argument(
  172. "--encryption-keyfile",
  173. default=None,
  174. help="The file path to a binary file containing a binary key to "
  175. "use for decryption. Can be a file path or S3 network URI.")
  176. group.add_argument(
  177. "--num-readers",
  178. default=1,
  179. type=int,
  180. help="Controls how many threads are allowed to read concurrently "
  181. "from the source file.")
  182. group.add_argument(
  183. "--s3-access-key-id",
  184. default=None,
  185. help="The access key for the S3 bucket. Can also be set via the "
  186. "S3_ACCESS_KEY_ID environment variable.",
  187. )
  188. group.add_argument(
  189. "--s3-secret-access-key",
  190. default=None,
  191. help="The secret access key for the S3 bucket. Can also be set via "
  192. "the S3_SECRET_ACCESS_KEY environment variable.",
  193. )
  194. group.add_argument(
  195. "--s3-endpoint",
  196. default=None,
  197. help="The endpoint for the S3 bucket. Can also be set via the "
  198. "S3_ENDPOINT_URL environment variable.",
  199. )
  200. group.add_argument(
  201. "--aphrodite-tensorized",
  202. action="store_true",
  203. help=
  204. "If enabled, indicates that the serialized model is a aphrodite "
  205. "model. This is used to determine the behavior of the "
  206. "TensorDeserializer when loading tensors from a "
  207. "serialized model.")
  208. return parser
  209. @classmethod
  210. def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
  211. attrs = [attr.name for attr in dataclasses.fields(cls)]
  212. tensorizer_args = cls(**{
  213. attr: getattr(args, attr)
  214. for attr in attrs if hasattr(args, attr)
  215. })
  216. return tensorizer_args
  217. class TensorizerAgent:
  218. """
  219. A class for performing tensorizer deserializations specifically for
  220. aphrodite models using plaid_mode. Uses TensorizerArgs to configure the
  221. behavior of the TensorDeserializer when loading tensors from a serialized
  222. model. For deserializations of HuggingFace models, TensorDeserializer is
  223. instead used as an iterator directly in the func hf_model_weights_iterator
  224. in aphrodite/modeling/model_loader/weight_utils.py
  225. """
  226. def __init__(self, tensorizer_config: TensorizerConfig,
  227. linear_method: LinearMethodBase, **extra_kwargs):
  228. if tensorizer_load_fail is not None:
  229. raise ImportError(
  230. "Tensorizer is not installed. Please install tensorizer "
  231. "to use this feature with "
  232. "`pip install aphrodite-engine[tensorizer]`."
  233. ) from tensorizer_load_fail
  234. self.tensorizer_config = tensorizer_config
  235. self.tensorizer_args = (
  236. self.tensorizer_config._construct_tensorizer_args())
  237. self.extra_kwargs = extra_kwargs
  238. if extra_kwargs.get("linear_method", None) is not None:
  239. self.linear_method = extra_kwargs["linear_method"]
  240. else:
  241. self.linear_method = linear_method
  242. self.model = self._init_model()
  243. def _init_model(self):
  244. model_args = self.tensorizer_config.hf_config
  245. model_args.torch_dtype = self.tensorizer_config.dtype
  246. with no_init_or_tensor():
  247. return self.tensorizer_config.model_class(
  248. config=model_args,
  249. linear_method=self.linear_method,
  250. **self.extra_kwargs)
  251. def _resize_lora_embeddings(self):
  252. """Modify LoRA embedding layers to use bigger tensors
  253. to allow for adapter added tokens."""
  254. for child in self.model.modules():
  255. if (isinstance(child, VocabParallelEmbedding)
  256. and child.weight.shape[0] <
  257. child.num_embeddings_per_partition):
  258. new_weight = torch.empty(child.num_embeddings_per_partition,
  259. child.embedding_dim,
  260. dtype=child.weight.dtype,
  261. device=child.weight.device)
  262. new_weight[:child.weight.shape[0]].copy_(child.weight.data)
  263. new_weight[child.weight.shape[0]:].fill_(0)
  264. child.weight.data = new_weight
  265. def _check_tensors_on_meta_device(self):
  266. for tensor in self.model.state_dict().values():
  267. if tensor.device.type == 'meta':
  268. raise ValueError(
  269. "The serialized model contains tensors on the meta device,"
  270. " indicating that some tensors were not loaded properly."
  271. " Please check that the parameters of the model being"
  272. " specified match that of the serialized model, such as"
  273. " its quantization.")
  274. def deserialize(self):
  275. """
  276. Deserialize the model using the TensorDeserializer. This method is
  277. specifically for Aphrodite models using tensorizer's plaid_mode.
  278. The deserializer makes use of tensorizer_args.stream_params
  279. to configure the behavior of the stream when loading tensors from a
  280. serialized model. The deserializer_params are used to configure the
  281. behavior of the TensorDeserializer when loading tensors themselves.
  282. Documentation on these params can be found in TensorizerArgs
  283. Returns:
  284. nn.Module: The deserialized model.
  285. """
  286. before_mem = get_mem_usage()
  287. start = time.perf_counter()
  288. with open_stream(
  289. self.tensorizer_args.tensorizer_uri,
  290. mode="rb",
  291. **self.tensorizer_args.stream_params,
  292. ) as stream, TensorDeserializer(
  293. stream,
  294. dtype=self.tensorizer_config.dtype,
  295. **self.tensorizer_args.deserializer_params) as deserializer:
  296. deserializer.load_into_module(self.model)
  297. end = time.perf_counter()
  298. total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
  299. duration = end - start
  300. per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
  301. after_mem = get_mem_usage()
  302. deserializer.close()
  303. logger.info(f"Deserialized {total_bytes_str} in "
  304. f"{end - start:0.2f}s, {per_second}/s")
  305. logger.info(f"Memory usage before: {before_mem}")
  306. logger.info(f"Memory usage after: {after_mem}")
  307. self._check_tensors_on_meta_device()
  308. self._resize_lora_embeddings()
  309. return self.model.eval()
  310. def tensorizer_weights_iterator(
  311. tensorizer_args: "TensorizerArgs"
  312. ) -> Generator[Tuple[str, torch.Tensor], None, None]:
  313. logger.warning(
  314. "Deserializing HuggingFace models is not optimized for "
  315. "loading on Aphrodite, as tensorizer is forced to load to CPU. "
  316. "Consider deserializing a Aphrodite model instead for faster "
  317. "load times. See the examples/tensorize_aphrodite_model.py example "
  318. "script for serializing Aphrodite models.")
  319. deserializer_args = tensorizer_args.deserializer_params
  320. stream_params = tensorizer_args.stream_params
  321. stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
  322. with TensorDeserializer(stream, **deserializer_args,
  323. device="cpu") as state:
  324. for name, param in state.items():
  325. yield name, param
  326. del state