# config.py

from typing import Optional, Union

import torch
from transformers import PretrainedConfig

from aphrodite.common.logger import init_logger
from aphrodite.transformers_utils.config import get_config
from aphrodite.common.utils import get_cpu_memory, is_hip

logger = init_logger(__name__)

_GB = 1 << 30


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights; defaults to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16
            precision for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we will fall
            back to eager mode.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.quantization = quantization
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture

        self.hf_config = get_config(model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = []
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")

        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        architectures = getattr(self.hf_config, "architectures", [])
        if "MixtralForCausalLM" in architectures and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq", "squeezellm", "gptq"]
        rocm_not_supported_quantization = ["awq"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        hf_quant_config = getattr(self.hf_config, "quant_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    f"Model quantization method is {hf_quant_method} "
                    f"but quantization argument is {self.quantization}. "
                    "Please use the same quantization method.")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. "
                    f"Must be one of {supported_quantization}.")
            if (is_hip() and
                    self.quantization in rocm_not_supported_quantization):
                raise ValueError(
                    f"{self.quantization} quantization method is currently "
                    "not supported in ROCm.")

        if self.quantization is not None:
            logger.warning(f"{self.quantization} quantization is not fully "
                           "optimized yet. The speed can be slower than "
                           "non-quantized models (16/32bit).")

    def _verify_cuda_graph(self) -> None:
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        return getattr(self.hf_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        return self.hf_config.vocab_size

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        # FIXME: This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        attributes = [
            "n_head_kv",
            "num_kv_heads",
            "num_key_value_heads",
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)
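
    # Worked example (illustrative): a model with 8 total KV heads and
    # tensor_parallel_size=4 gets 8 // 4 == 2 KV heads per GPU; with
    # tensor_parallel_size=16 the heads are replicated so that every GPU
    # still gets max(1, 8 // 16) == 1 KV head.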

    def get_max_model_len(self) -> int:
        return self.max_model_len

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
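

# Example usage (a hedged sketch, not executed on import): constructing a
# ModelConfig directly. The model name and lengths below are illustrative
# only; note that instantiation fetches the HuggingFace config via
# get_config().
#
#   model_config = ModelConfig(
#       model="mistralai/Mistral-7B-v0.1",
#       tokenizer="mistralai/Mistral-7B-v0.1",
#       tokenizer_mode="auto",
#       trust_remote_code=False,
#       download_dir=None,
#       load_format="auto",
#       dtype="auto",
#       seed=0,
#       max_model_len=4096,
#   )
#   # Assuming the model's derived maximum length is at least 4096:
#   assert model_config.get_max_model_len() == 4096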


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            Aphrodite execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for the KV cache.
        sliding_window: Sliding window size of the attention, if the model
            uses one.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.cache_dtype = cache_dtype
        if cache_dtype and "fp8" in cache_dtype.lower():
            # As FP8 is not a formal data type, we use
            # torch.uint8 instead.
            self.cache_dtype = torch.uint8
        self.sliding_window = sliding_window
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def _verify_args(self) -> None:
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must not exceed 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_cpu_memory = get_cpu_memory()
        # FIXME: Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple
        # nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
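

# Example (illustrative): 16-token cache blocks, 90% GPU memory utilization,
# and 4 GiB of CPU swap space per GPU. With tensor_parallel_size=2,
# verify_with_parallel_config accounts for 2 * 4 GiB = 8 GiB of CPU memory
# and warns above 40% (errors above 70%) of the total CPU memory.
#
#   cache_config = CacheConfig(
#       block_size=16,
#       gpu_memory_utilization=0.90,
#       swap_space=4,
#       cache_dtype="fp8",  # stored internally as torch.uint8 (see __init__)
#   )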


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
        max_parallel_loading_workers: Maximum number of workers that may load
            the model weights in parallel.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if self.world_size > 1:
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
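

# Example (illustrative): tensor parallelism across 2 GPUs. world_size becomes
# 1 * 2 == 2, so worker_use_ray is forced to True even though False was passed.
# Requesting pipeline_parallel_size > 1 raises NotImplementedError.
#
#   parallel_config = ParallelConfig(
#       pipeline_parallel_size=1,
#       tensor_parallel_size=2,
#       worker_use_ray=False,
#   )
#   assert parallel_config.worker_use_ray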


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        max_paddings: Maximum number of paddings to be added to a batch.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        max_paddings: int,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            # If max_model_len is short, still allow batches of at least
            # 2048 tokens for throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_paddings = max_paddings
        self._verify_args()

    def _verify_args(self) -> None:
        if self.max_num_batched_tokens < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes Aphrodite reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")
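

# Example (illustrative): with max_num_batched_tokens left as None, the value
# defaults to max(max_model_len, 2048), so a 4096-token model schedules up to
# 4096 tokens per iteration.
#
#   scheduler_config = SchedulerConfig(
#       max_num_batched_tokens=None,
#       max_num_seqs=256,
#       max_model_len=4096,
#       max_paddings=256,
#   )
#   assert scheduler_config.max_num_batched_tokens == 4096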


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}. Must be one of "
                                 f"{list(_STR_DTYPE_TO_TORCH_DTYPE.keys())}.")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(
            f"Unknown dtype: {dtype}. Must be either a string or a torch "
            "dtype.")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
        ]
        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    return torch_dtype
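

# Example (illustrative; the config objects named below are hypothetical
# placeholders): how "auto" resolves. A config whose torch_dtype is float32
# (or missing) yields float16; otherwise the config's own dtype is kept, and
# casting between float16 and bfloat16 only logs a warning.
#
#   _get_and_verify_dtype(config_with_fp32_dtype, "auto")      # torch.float16
#   _get_and_verify_dtype(config_with_bf16_dtype, "auto")      # torch.bfloat16
#   _get_and_verify_dtype(config_with_bf16_dtype, "float16")   # torch.float16, warns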


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        "max_position_embeddings",
        "n_positions",
        "max_seq_len",
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)

    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        # The user asked for a longer context than the model natively
        # supports; fall back to dynamic RoPE scaling and hope the model
        # can cope with it.
        scaling_factor = max_model_len / derived_max_model_len
        hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
        logger.warning(
            f"User-specified max_model_len {max_model_len} is higher than "
            f"the original {derived_max_model_len}. "
            "Attempting to use RoPE scaling.")
        derived_max_model_len = max_model_len

    return int(max_model_len)
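

# Example (illustrative): if config.json reports max_position_embeddings=4096
# and the user passes max_model_len=8192, the function installs dynamic RoPE
# scaling with factor 8192 / 4096 == 2.0 on hf_config.rope_scaling and returns
# 8192. With max_model_len=None it simply returns the derived 4096.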