config.py

import enum
from typing import TYPE_CHECKING, Optional, Union, ClassVar
from dataclasses import dataclass, fields
import os
from packaging.version import Version
from loguru import logger
import json
import torch
from transformers import PretrainedConfig

from aphrodite.transformers_utils.config import get_config, get_hf_text_config
from aphrodite.common.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron,
                                    get_nvcc_cuda_version)

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

_GB = 1 << 30


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights; defaults to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if the safetensors format
                is not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16
            precision for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        load_in_4bit: Whether to load the FP16 model in bitsandbytes 4bit
            format. Works with AWQ models as well as FP16.
        load_in_8bit: Whether to load the FP16 model in 8bit format. Slower
            than load_in_smooth in terms of throughput.
        load_in_smooth: Whether to load the FP16 model in smoothquant format.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when the KV
            cache type is FP8_E4M3 on ROCm (AMD GPU). In the future these will
            also be used to load activation and weight scaling factors when
            the model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        # dtype: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        load_in_smooth: bool = False,
        quantization_param_path: Optional[str] = None,
        enforce_eager: bool = True,
        max_context_len_to_capture: Optional[int] = None,
        max_log_probs: int = 10,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.load_in_4bit = load_in_4bit
        self.load_in_8bit = load_in_8bit
        self.load_in_smooth = load_in_smooth
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        self.max_log_probs = max_log_probs

        if os.environ.get("APHRODITE_USE_MODELSCOPE",
                          "False").lower() == "true":
            # Download model from ModelScope hub.
            # Lazy import so that modelscope is not required for normal use.
            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
            if not os.path.exists(model):
                model_path = snapshot_download(model_id=model,
                                               cache_dir=download_dir,
                                               revision=revision)
            else:
                model_path = model
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = []
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are {rocm_supported_load_format}")

        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        architectures = getattr(self.hf_config, "architectures", [])
        # architectures can be None instead of []
        if architectures and "MixtralForCausalLM" in architectures \
                and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        supported_quantization = [
            "aqlm", "awq", "bnb", "eetq", "exl2", "gguf", "gptq", "quip",
            "squeezellm", "marlin"
        ]
        rocm_not_supported_quantization = ["aqlm", "awq", "bnb", "quip"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        if self.model.endswith("gguf"):
            if self.quantization is None:
                self.quantization = "gguf"
            elif self.quantization != "gguf":
                raise ValueError(
                    "GGUF models cannot be used with quantization method "
                    f"({self.quantization}).")

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            # If the GPTQ model is serialized in marlin format, use marlin.
            if (hf_quant_method == "gptq"
                    and "is_marlin_format" in hf_quant_config
                    and hf_quant_config["is_marlin_format"]):
                hf_quant_method = "marlin"
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.load_in_4bit:
            # The kernels seem to not work with 4bit weight_only.
            if torch.cuda.get_device_capability(0)[0] < 8:
                raise ValueError(
                    "load_in_4bit quantization is not supported on GPUs with "
                    "compute capability less than 8.0.")
            if self.quantization is None:
                self.quantization = "bnb"
                self.hf_config.quantization_config = {
                    "bits": 4,
                    "quant_mode": "weight_only",
                    "quant_method": "bnb",
                    "group_size": 128,
                    "zero_point": True,
                    "from_float": True
                }
            elif self.quantization == "awq":
                logger.warning("AWQ model is being loaded in 4bit bnb format.")
                self.quantization = "bnb"
                self.hf_config.quantization_config = {
                    "zero_point": True,
                    "q_group_size": 128,
                    "w_bit": 4,
                    "version": "gemm"
                }
            elif self.quantization != "bnb":
                raise ValueError("4bit quantization is not supported in "
                                 f"{self.quantization}.")

        if self.load_in_8bit:
            if self.quantization is None:
                self.quantization = "bnb"
            elif self.quantization != "bnb":
                raise ValueError("8bit quantization is not supported in "
                                 f"{self.quantization}.")
            self.hf_config.quantization_config = {
                "bits": 8,
                "quant_mode": "llm_int8",
                "quant_method": "bnb",
                "group_size": 128,
                "zero_point": True,
                "from_float": True
            }
            self.enforce_eager = True

        if self.load_in_smooth:
            if self.quantization is None:
                self.quantization = "bnb"
            elif self.quantization != "bnb":
                raise ValueError("Smooth quantization is not supported in "
                                 f"{self.quantization}.")
            self.hf_config.quantization_config = {
                "bits": 8,
                "quant_mode": "smoothquant",
                "quant_method": "bnb",
                "group_size": 128,
                "zero_point": True,
                "from_float": True
            }
            self.enforce_eager = True

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization in rocm_not_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    "supported in ROCm.")
            if self.quantization != "marlin":
                logger.warning(
                    f"{self.quantization} quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_num_attention_heads = self.hf_text_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        return self.hf_text_config.vocab_size

    def get_hidden_size(self) -> int:
        return self.hf_text_config.hidden_size

    def get_head_size(self) -> int:
        if hasattr(self.hf_config, "head_dim"):
            return self.hf_config.head_dim
        # FIXME: This may not be true for all models.
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            return getattr(self.hf_config.attn_config, "kv_n_heads",
                           self.hf_config.num_attention_heads)

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
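

# Illustrative usage sketch (hypothetical helper; the checkpoint name and
# argument values are placeholders, not part of the original module):
# constructing a ModelConfig and letting dtype and max_model_len be derived
# from the HF config by _get_and_verify_dtype() and _get_and_verify_max_len().
def _example_model_config() -> "ModelConfig":
    return ModelConfig(
        model="mistralai/Mistral-7B-v0.1",  # placeholder checkpoint
        tokenizer="mistralai/Mistral-7B-v0.1",
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        load_format="auto",
        dtype="auto",        # resolved against the checkpoint's torch_dtype
        seed=0,
        max_model_len=None,  # derived from the HF config when None
    )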


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            Aphrodite execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for KV cache storage.
        cache_quant_params_path: Path to the scales and zero points
            of KV cache quantization when cache_dtype is int8.
        num_gpu_blocks_override: Number of GPU blocks to use. This overrides
            the profiled num_gpu_blocks if specified. Does nothing if None.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        # cache_quant_params_path: Optional[str] = None,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
        context_shift: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        # self.cache_quant_params_path = cache_quant_params_path
        self.context_shift = context_shift
        self._verify_args()
        self._verify_cache_dtype()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        # Convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info.
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        if self.cache_dtype == "auto":
            # if self.cache_dtype in ["auto", "int8"]:
            pass
        elif self.cache_dtype == "fp8":
            if not is_hip():
                nvcc_cuda_version = get_nvcc_cuda_version()
                if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
                    raise ValueError(
                        "FP8 is not supported when cuda version is "
                        "lower than 11.8.")
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "But it may cause slight accuracy drop without scaling "
                "factors. FP8_E5M2 (without scaling) is only supported on "
                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
                "is instead supported for common inference criteria.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_cpu_memory = get_cpu_memory()
        # FIXME: Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple
        # nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
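

# Illustrative usage sketch (hypothetical helper; values are placeholders):
# a CacheConfig with a 4 GiB CPU swap space per GPU. With
# tensor_parallel_size=2, verify_with_parallel_config() would account for
# 2 * 4 GiB = 8 GiB of host memory and warn or error if that exceeds 40% or
# 70% of total CPU memory.
def _example_cache_config() -> "CacheConfig":
    return CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=4,        # GiB per GPU, stored internally in bytes
        cache_dtype="auto",  # "fp8" requires CUDA >= 11.8 on NVIDIA GPUs
    )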


@dataclass
class TokenizerPoolConfig:
    """Configuration for the tokenizer pool.

    Args:
        pool_size: Number of tokenizer instances in the pool.
        pool_type: Type of the tokenizer pool.
        extra_config: Additional config for the pool.
            The way the config will be used depends on the
            pool type.
    """
    pool_size: int
    pool_type: str
    extra_config: dict

    def __post_init__(self):
        if self.pool_type not in ("ray", ):
            raise ValueError(f"Unknown pool type: {self.pool_type}.")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Create a TokenizerPoolConfig from the given parameters.

        If tokenizer_pool_size is 0, return None.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the tokenizer pool.
            tokenizer_pool_extra_config: Additional config for the pool.
                The way the config will be used depends on the pool type.
        """
        if tokenizer_pool_size:
            if isinstance(tokenizer_pool_extra_config, str):
                tokenizer_pool_extra_config_parsed = json.loads(
                    tokenizer_pool_extra_config)
            else:
                tokenizer_pool_extra_config_parsed = (
                    tokenizer_pool_extra_config or {})
            tokenizer_pool_config = cls(tokenizer_pool_size,
                                        tokenizer_pool_type,
                                        tokenizer_pool_extra_config_parsed)
        else:
            tokenizer_pool_config = None
        return tokenizer_pool_config
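

# Illustrative usage sketch (hypothetical helper; the extra-config contents
# are placeholders): create_config() accepts the extra config either as a
# dict or as a JSON string, and returns None when the pool size is 0
# (synchronous tokenization).
def _example_tokenizer_pool_config() -> Optional["TokenizerPoolConfig"]:
    return TokenizerPoolConfig.create_config(
        tokenizer_pool_size=4,
        tokenizer_pool_type="ray",
        tokenizer_pool_extra_config='{"example_key": 1}',  # parsed by json.loads
    )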


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
        max_parallel_loading_workers: Maximum number of workers allowed to
            load the model in parallel at a time (loading is done in batches
            of this size), to avoid RAM OOM when using tensor parallelism
            with large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        tokenizer_pool_config: Configuration for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
        ray_workers_use_nsight: bool = False,
        placement_group: Optional["PlacementGroup"] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.tokenizer_pool_config = tokenizer_pool_config
        self.ray_workers_use_nsight = ray_workers_use_nsight
        self.placement_group = placement_group

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        if self.world_size > 1:
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
        if not self.disable_custom_all_reduce and self.world_size > 1:
            if is_hip():
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
            elif self.pipeline_parallel_size > 1:
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported with pipeline parallelism.")
        if self.ray_workers_use_nsight and not self.worker_use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        num_lookahead_slots: The number of slots to allocate per sequence per
            step, beyond the known token ids. This is used in speculative
            decoding to store KV activations of tokens which may or may not
            be accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling the next prompt.
        policy: Policy of sequence scheduling (`fcfs` or `reorder`).
        reorder_window: Allowed reorder window size (in sec) for the `reorder`
            policy.
        enable_chunked_prefill: If True, prefill requests can be chunked
            based on the remaining max_num_batched_tokens.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        use_v2_block_manager: bool = False,
        num_lookahead_slots: int = 0,
        delay_factor: float = 0.0,
        policy: str = "fcfs",
        reorder_window: float = 0.0,
        enable_chunked_prefill: bool = False,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            if enable_chunked_prefill:
                # For chunked prefill, choose the well-tuned batch size.
                self.max_num_batched_tokens = 768
            else:
                # If max_model_len is too short, use 2048 as the default value
                # for higher throughput.
                self.max_num_batched_tokens = max(max_model_len, 2048)
        if enable_chunked_prefill:
            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
        self.policy = policy
        self.reorder_window = reorder_window
        self.chunked_prefill_enabled = enable_chunked_prefill
        self._verify_args()

    def _verify_args(self) -> None:
        if (self.max_num_batched_tokens < self.max_model_len
                and not self.chunked_prefill_enabled):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes Aphrodite reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")
        if self.reorder_window < 0:
            raise ValueError(f"reorder_window ({self.reorder_window}) must "
                             "not be negative.")
        if self.reorder_window != 0 and self.policy != 'reorder':
            raise ValueError(
                f"reorder_window ({self.reorder_window}) is only supported "
                "with the 'reorder' policy.")
        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")


class DeviceConfig:

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection
            if torch.cuda.is_available():
                self.device_type = "cuda"
            elif is_neuron():
                self.device_type = "neuron"
            elif is_cpu():
                self.device_type = "cpu"
            else:
                raise RuntimeError("No supported device detected.")
        else:
            # Device type is assigned explicitly
            self.device_type = device

        # Some device types require processing inputs on CPU
        if self.device_type in ["neuron"]:
            self.device = torch.device("cpu")
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)


class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        num_speculative_tokens: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on
        the provided parameters. If the necessary conditions are met, it
        returns an instance of SpeculativeConfig. Otherwise, it returns None.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel
                configuration for the target model.
            target_dtype (str): The data type used for the target model.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.
        """
        if speculative_model is None and num_speculative_tokens is None:
            return None

        if speculative_model is not None and num_speculative_tokens is None:
            raise ValueError(
                "Expected both speculative_model and "
                "num_speculative_tokens to be provided, but found "
                f"{speculative_model=} and {num_speculative_tokens=}.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None
        draft_max_model_len = None

        draft_model_config = ModelConfig(
            model=speculative_model,
            tokenizer=target_model_config.tokenizer,
            tokenizer_mode=target_model_config.tokenizer_mode,
            trust_remote_code=target_model_config.trust_remote_code,
            download_dir=target_model_config.download_dir,
            load_format=target_model_config.load_format,
            dtype=target_model_config.dtype,
            seed=target_model_config.seed,
            revision=draft_revision,
            code_revision=draft_code_revision,
            tokenizer_revision=target_model_config.tokenizer_revision,
            max_model_len=draft_max_model_len,
            quantization=draft_quantization,
            enforce_eager=target_model_config.enforce_eager,
            max_context_len_to_capture=target_model_config.
            max_context_len_to_capture,
            max_log_probs=target_model_config.max_log_probs,
        )

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
        )

    @staticmethod
    def create_draft_parallel_config(
            target_parallel_config: ParallelConfig) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config. In the future the
        draft worker can have a different parallel strategy, e.g. TP=1.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
            worker_use_ray=target_parallel_config.worker_use_ray,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )
        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self._verify_args()

    def _verify_args(self) -> None:
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


@dataclass
class LoRAConfig:
    max_lora_rank: int
    max_loras: int
    max_cpu_loras: Optional[int] = None
    lora_dtype: Optional[torch.dtype] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256

    def __post_init__(self):
        # Keep this in sync with kernels/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if (model_config.quantization is not None
                and model_config.quantization == "gguf"):
            raise ValueError("LoRA is not supported with GGUF quantization.")

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")


@dataclass
class VisionLanguageConfig:
    """Configures the input data format and how models should run for
    vision language models."""

    class ImageInputType(enum.Enum):
        """Image input type into the vision language model.

        An image roughly goes through the following transformation:
        Raw image --> pixel values --> image features --> image embeddings.

        The difference between different image input types is where the
        image encoder (pixel values --> image features) is run.
        Different image input types also correspond to different tensor
        shapes. For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336) and
        IMAGE_FEATURES: (1, 576, 1024).
        """
        PIXEL_VALUES = enum.auto()
        IMAGE_FEATURES = enum.auto()

    image_input_type: ImageInputType
    # The input id corresponding to the image token.
    image_token_id: int
    # Used for running `run_prefill_max_token`.
    # For models that support varying resolution, this corresponds to
    # the worst case scenario (biggest supported resolution).
    image_input_shape: tuple
    image_feature_size: int

    @classmethod
    def get_image_input_enum_type(
            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
        """Get the image input type from a string."""
        try:
            return cls.ImageInputType[value.upper()]
        except KeyError as e:
            raise ValueError(f"{value} is not a valid choice. "
                             f"Expecting to choose from "
                             f"{[x.name for x in cls.ImageInputType]}.") from e


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                # Following the common practice, we use float16 for float32
                # models.
                torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
        ]
        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    return torch_dtype
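

# Illustrative usage sketch (hypothetical helper): with dtype="auto", a
# checkpoint whose config reports torch_dtype=float32 is loaded in float16,
# while a bfloat16 checkpoint keeps bfloat16.
def _example_resolve_dtype(hf_config: PretrainedConfig) -> torch.dtype:
    return _get_and_verify_dtype(hf_config, "auto")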


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
        "n_ctx",
    ]
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        # Fall back to dynamic RoPE scaling to reach the requested length.
        scaling_factor = max_model_len / derived_max_model_len
        hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
        logger.warning(
            f"User-specified max_model_len {max_model_len} is higher than "
            f"the original {derived_max_model_len}. "
            "Attempting to use RoPE scaling.")
        derived_max_model_len = max_model_len

    return int(max_model_len)
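

# Illustrative usage sketch (hypothetical helper): for a config with
# max_position_embeddings=4096 and rope_scaling={"type": "linear",
# "factor": 2.0}, the derived maximum length is 4096 * 2.0 = 8192, and
# passing max_model_len=None then yields 8192.
def _example_max_len(hf_config: PretrainedConfig) -> int:
    return _get_and_verify_max_len(hf_config, max_model_len=None)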


@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """
    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other."""
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs."""
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))
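

# Illustrative usage sketch (hypothetical helper, reusing the hypothetical
# example builders above): bundling the individual configs into a frozen
# EngineConfig, which runs the cross-config verification in __post_init__,
# and expanding it back into keyword arguments via to_dict().
def _example_engine_config() -> "EngineConfig":
    engine_config = EngineConfig(
        model_config=_example_model_config(),
        cache_config=_example_cache_config(),
        parallel_config=_example_parallel_config(),
        scheduler_config=_example_scheduler_config(),
        device_config=DeviceConfig("auto"),
        lora_config=None,
        vision_language_config=None,
        speculative_config=None,
    )
    # A consumer could then unpack it, e.g. SomeEngine(**engine_config.to_dict())
    # (SomeEngine is a placeholder name, not an Aphrodite API).
    return engine_config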