config.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. from typing import Optional
  2. import torch
  3. from transformers import PretrainedConfig
  4. from transformers.utils.quantization_config import QuantizationMethod
  5. from aphrodite.common.logger import init_logger
  6. from aphrodite.transformers_utils.config import get_config
  7. from aphrodite.common.utils import get_cpu_memory
# Module-level logger, created through the project's logging helper.
logger = init_logger(__name__)
# Number of bytes in one GiB (2**30); used to convert between GiB and bytes.
_GB = 1 << 30
  10. class ModelConfig:
  11. """Configuration for the model.
  12. Args:
  13. model: Name or path of the huggingface model to use.
  14. tokenizer: Name or path of the huggingface tokenizer to use.
  15. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
  16. available, and "slow" will always use the slow tokenizer.
  17. trust_remote_code: Trust remote code (e.g., from HuggingFace) when
  18. downloading the model and tokenizer.
  19. download_dir: Directory to download and load the weights, default to the
  20. default cache directory of huggingface.
  21. load_format: The format of the model weights to load:
  22. "auto" will try to load the weights in the safetensors format and
  23. fall back to the pytorch bin format if safetensors format is
  24. not available.
  25. "pt" will load the weights in the pytorch bin format.
  26. "safetensors" will load the weights in the safetensors format.
  27. "npcache" will load the weights in pytorch format and store
  28. a numpy cache to speed up the loading.
  29. "dummy" will initialize the weights with random values, which is
  30. mainly for profiling.
  31. dtype: Data type for model weights and activations. The "auto" option
  32. will use FP16 precision for FP32 and FP16 models, and BF16 precision
  33. for BF16 models.
  34. seed: Random seed for reproducibility.
  35. revision: The specific model version to use. It can be a branch name,
  36. a tag name, or a commit id. If unspecified, will use the default
  37. version.
  38. max_model_len: Maximum length of a sequence (including prompt and
  39. output). If None, will be derived from the model.
  40. quantization: Quantization method that was used to quantize the model
  41. weights. If None, we assume the model weights are not quantized.
  42. """
  43. def __init__(
  44. self,
  45. model: str,
  46. tokenizer: str,
  47. tokenizer_mode: str,
  48. trust_remote_code: bool,
  49. download_dir: Optional[str],
  50. load_format: str,
  51. dtype: str,
  52. seed: int,
  53. revision: Optional[str] = None,
  54. max_model_len: Optional[int] = None,
  55. quantization: Optional[str] = None,
  56. ) -> None:
  57. self.model = model
  58. self.tokenizer = tokenizer
  59. self.tokenizer_mode = tokenizer_mode
  60. self.trust_remote_code = trust_remote_code
  61. self.download_dir = download_dir
  62. self.load_format = load_format
  63. self.seed = seed
  64. self.revision = revision
  65. self.quantization = quantization
  66. self.hf_config = get_config(model, trust_remote_code, revision)
  67. self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
  68. self.max_model_len = _get_and_verify_max_len(self.hf_config,
  69. max_model_len)
  70. self._verify_load_format()
  71. self._verify_tokenizer_mode()
  72. self._verify_quantization()
  73. def _verify_load_format(self) -> None:
  74. load_format = self.load_format.lower()
  75. if load_format not in [
  76. "auto", "pt", "safetensors", "npcache", "dummy"
  77. ]:
  78. raise ValueError(
  79. f"Unknown load format: {self.load_format}. Must be one of "
  80. "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
  81. self.load_format = load_format
  82. def _verify_tokenizer_mode(self) -> None:
  83. tokenizer_mode = self.tokenizer_mode.lower()
  84. if tokenizer_mode not in ["auto", "slow"]:
  85. raise ValueError(
  86. f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
  87. "either 'auto' or 'slow'.")
  88. self.tokenizer_mode = tokenizer_mode
  89. def _verify_quantization(self) -> None:
  90. supported_quantization = ["awq", "gptq"]
  91. if hasattr(self.hf_config, "quantization_config"
  92. ) and self.hf_config.quantization_config.get(
  93. "quant_method") == QuantizationMethod.GPTQ:
  94. self.quantization = "gptq"
  95. if self.quantization is None:
  96. return
  97. quantization = self.quantization.lower()
  98. if quantization not in supported_quantization:
  99. raise ValueError(
  100. f"Unknown quantization: {self.quantization}. Must be one of "
  101. f"{supported_quantization}.")
  102. self.quantization = quantization
  103. def verify_with_parallel_config(
  104. self,
  105. parallel_config: "ParallelConfig",
  106. ) -> None:
  107. total_num_attention_heads = self.hf_config.num_attention_heads
  108. tensor_parallel_size = parallel_config.tensor_parallel_size
  109. if total_num_attention_heads % tensor_parallel_size != 0:
  110. raise ValueError(
  111. f"Total number of attention heads ({total_num_attention_heads})"
  112. " must be divisible by tensor parallel size "
  113. f"({tensor_parallel_size}).")
  114. total_num_hidden_layers = self.hf_config.num_hidden_layers
  115. pipeline_parallel_size = parallel_config.pipeline_parallel_size
  116. if total_num_hidden_layers % pipeline_parallel_size != 0:
  117. raise ValueError(
  118. f"Total number of hidden layers ({total_num_hidden_layers}) "
  119. "must be divisible by pipeline parallel size "
  120. f"({pipeline_parallel_size}).")
  121. def get_hidden_size(self) -> int:
  122. return self.hf_config.hidden_size
  123. def get_head_size(self) -> int:
  124. # FIXME(woosuk): This may not be true for all models.
  125. return self.hf_config.hidden_size // self.hf_config.num_attention_heads
  126. def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
  127. """Returns the number of KV heads per GPU worker."""
  128. if getattr(self.hf_config, "n_head_kv", None) is not None:
  129. return (self.hf_config.n_head_kv //
  130. parallel_config.tensor_parallel_size)
  131. if getattr(self.hf_config, "num_kv_heads", None) is not None:
  132. return (self.hf_config.num_kv_heads //
  133. parallel_config.tensor_parallel_size)
  134. # For LLaMA-2:
  135. if getattr(self.hf_config, "num_key_value_heads", None) is not None:
  136. return (self.hf_config.num_key_value_heads //
  137. parallel_config.tensor_parallel_size)
  138. total_num_attention_heads = self.hf_config.num_attention_heads
  139. return total_num_attention_heads // parallel_config.tensor_parallel_size
  140. def get_max_model_len(self) -> int:
  141. return self.max_model_len
  142. def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  143. total_num_hidden_layers = self.hf_config.num_hidden_layers
  144. return total_num_hidden_layers // parallel_config.pipeline_parallel_size
  145. class CacheConfig:
  146. """Configuration for the KV cache.
  147. Args:
  148. block_size: Size of a cache block in number of tokens.
  149. gpu_memory_utilization: Fraction of GPU memory to use for the
  150. Aphrodite execution.
  151. swap_space: Size of the CPU swap space per GPU (in GiB).
  152. """
  153. def __init__(
  154. self,
  155. block_size: int,
  156. gpu_memory_utilization: float,
  157. swap_space: int,
  158. sliding_window: Optional[int] = None,
  159. ) -> None:
  160. self.block_size = block_size
  161. self.gpu_memory_utilization = gpu_memory_utilization
  162. self.swap_space_bytes = swap_space * _GB
  163. self.sliding_window = sliding_window
  164. self._verify_args()
  165. # Will be set after profiling.
  166. self.num_gpu_blocks = None
  167. self.num_cpu_blocks = None
  168. def _verify_args(self) -> None:
  169. if self.gpu_memory_utilization > 1.0:
  170. raise ValueError(
  171. "GPU memory utilization must be less than 1.0. Got "
  172. f"{self.gpu_memory_utilization}.")
  173. def verify_with_parallel_config(
  174. self,
  175. parallel_config: "ParallelConfig",
  176. ) -> None:
  177. total_cpu_memory = get_cpu_memory()
  178. # FIXME: Here, it is assumed that the GPUs in a tensor parallel
  179. # group are in the same node. However, the GPUs may span multiple nodes.
  180. num_gpus_per_node = parallel_config.tensor_parallel_size
  181. cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
  182. msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
  183. f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
  184. "allocated for the swap space.")
  185. if cpu_memory_usage > 0.7 * total_cpu_memory:
  186. raise ValueError("Too large swap space. " + msg)
  187. elif cpu_memory_usage > 0.4 * total_cpu_memory:
  188. logger.warning("Possibly too large swap space. " + msg)
  189. class ParallelConfig:
  190. """Configuration for the distributed execution.
  191. Args:
  192. pipeline_parallel_size: Number of pipeline parallel groups.
  193. tensor_parallel_size: Number of tensor parallel groups.
  194. worker_use_ray: Whether to use Ray for model workers. Will be set to
  195. True if either pipeline_parallel_size or tensor_parallel_size is
  196. greater than 1.
  197. """
  198. def __init__(
  199. self,
  200. pipeline_parallel_size: int,
  201. tensor_parallel_size: int,
  202. worker_use_ray: bool,
  203. ) -> None:
  204. self.pipeline_parallel_size = pipeline_parallel_size
  205. self.tensor_parallel_size = tensor_parallel_size
  206. self.worker_use_ray = worker_use_ray
  207. self.world_size = pipeline_parallel_size * tensor_parallel_size
  208. if self.world_size > 1:
  209. self.worker_use_ray = True
  210. self._verify_args()
  211. def _verify_args(self) -> None:
  212. if self.pipeline_parallel_size > 1:
  213. raise NotImplementedError(
  214. "Pipeline parallelism is not supported yet.")
  215. class SchedulerConfig:
  216. """Scheduler configuration.
  217. Args:
  218. max_num_batched_tokens: Maximum number of tokens to be processed in
  219. a single iteration.
  220. max_num_seqs: Maximum number of sequences to be processed in a single
  221. iteration.
  222. max_model_len: Maximum length of a sequence (including prompt
  223. and generated text).
  224. max_paddings: Maximum number of paddings to be added to a batch.
  225. """
  226. def __init__(
  227. self,
  228. max_num_batched_tokens: Optional[int],
  229. max_num_seqs: int,
  230. max_model_len: int,
  231. max_paddings: int,
  232. ) -> None:
  233. if max_num_batched_tokens is not None:
  234. self.max_num_batched_tokens = max_num_batched_tokens
  235. else:
  236. self.max_num_batched_tokens = max(max_model_len, 2048)
  237. self.max_num_seqs = max_num_seqs
  238. self.max_model_len = max_model_len
  239. self.max_paddings = max_paddings
  240. self._verify_args()
  241. def _verify_args(self) -> None:
  242. if self.max_num_batched_tokens < self.max_model_len:
  243. raise ValueError(
  244. f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
  245. f"smaller than max_model_len ({self.max_model_len}). "
  246. f"This effectively limits the maximum sequence length to "
  247. f"max_num_batched_tokens and makes Aphrodite reject longer "
  248. f"sequences. Please increase max_num_batched_tokens or "
  249. f"decrease max_model_len.")
  250. if self.max_num_batched_tokens < self.max_num_seqs:
  251. raise ValueError(
  252. f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
  253. "be greater than or equal to max_num_seqs "
  254. f"({self.max_num_seqs}).")
  255. _STR_DTYPE_TO_TORCH_DTYPE = {
  256. "half": torch.float16,
  257. "float16": torch.float16,
  258. "float": torch.float32,
  259. "float32": torch.float32,
  260. "bfloat16": torch.bfloat16,
  261. }
  262. def _get_and_verify_dtype(
  263. config: PretrainedConfig,
  264. dtype: str,
  265. ) -> torch.dtype:
  266. # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
  267. # because config.torch_dtype can be None.
  268. config_dtype = getattr(config, "torch_dtype", None)
  269. if config_dtype is None:
  270. config_dtype = torch.float32
  271. dtype = dtype.lower()
  272. if dtype == "auto":
  273. if config_dtype == torch.float32:
  274. # Following the common practice, we use float16 for float32 models.
  275. torch_dtype = torch.float16
  276. else:
  277. torch_dtype = config_dtype
  278. else:
  279. if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
  280. raise ValueError(f"Unknown dtype: {dtype}")
  281. torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
  282. # Verify the dtype.
  283. if torch_dtype != config_dtype:
  284. if torch_dtype == torch.float32:
  285. # Upcasting to float32 is allowed.
  286. pass
  287. elif config_dtype == torch.float32:
  288. # Downcasting from float32 to float16 or bfloat16 is allowed.
  289. pass
  290. else:
  291. # Casting between float16 and bfloat16 is allowed with a warning.
  292. logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
  293. return torch_dtype
  294. def _get_and_verify_max_len(
  295. hf_config: PretrainedConfig,
  296. max_model_len: Optional[int],
  297. ) -> int:
  298. """Get and verify the model's maximum length."""
  299. derived_max_model_len = float("inf")
  300. possible_keys = [
  301. "max_position_embeddings",
  302. "n_positions",
  303. "max_seq_len",
  304. "max_sequence_length",
  305. "max_seq_length",
  306. "seq_len",
  307. ]
  308. for key in possible_keys:
  309. max_len_key = getattr(hf_config, key, None)
  310. if max_len_key is not None:
  311. derived_max_model_len = min(derived_max_model_len, max_len_key)
  312. if derived_max_model_len == float("inf"):
  313. raise ValueError(
  314. "The model's config.json must contain one of the following keys "
  315. "to determine the original maximum length of the model: "
  316. f"{possible_keys}")
  317. rope_scaling = getattr(hf_config, "rope_scaling", None)
  318. if rope_scaling is not None:
  319. assert "factor" in rope_scaling
  320. scaling_factor = rope_scaling["factor"]
  321. derived_max_model_len *= scaling_factor
  322. if max_model_len is None:
  323. max_model_len = derived_max_model_len
  324. elif max_model_len > derived_max_model_len:
  325. raise ValueError(
  326. f"User-specified max_model_len ({max_model_len}) is greater than "
  327. f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
  328. " in model's config.json). This may lead to incorrect model "
  329. "outputs or CUDA errors. Make sure the value is correct and "
  330. "within the model context size.")
  331. return int(max_model_len)