config.py 49 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201
  1. import enum
  2. import json
  3. import os
  4. from dataclasses import dataclass, field, fields
  5. from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
  6. import torch
  7. from loguru import logger
  8. from packaging.version import Version
  9. from transformers import PretrainedConfig
  10. from aphrodite.common.utils import (get_cpu_memory, get_nvcc_cuda_version,
  11. is_cpu, is_hip, is_neuron)
  12. from aphrodite.quantization import QUANTIZATION_METHODS
  13. from aphrodite.transformers_utils.config import get_config, get_hf_text_config
  14. if TYPE_CHECKING:
  15. from ray.util.placement_group import PlacementGroup
  16. from aphrodite.modeling.model_loader.loader import BaseModelLoader
  17. # If true, will load models from ModelScope instead of Hugging Face Hub.
  18. APHRODITE_USE_MODELSCOPE = os.environ.get("APHRODITE_USE_MODELSCOPE",
  19. "False").lower() == "true"
  20. _GB = 1 << 30
  21. class ModelConfig:
  22. """Configuration for the model.
  23. Args:
  24. model: Name or path of the huggingface model to use.
  25. tokenizer: Name or path of the huggingface tokenizer to use.
  26. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
  27. available, and "slow" will always use the slow tokenizer.
  28. trust_remote_code: Trust remote code (e.g., from HuggingFace) when
  29. downloading the model and tokenizer.
  30. dtype: Data type for model weights and activations. The "auto" option
  31. will use FP16 precision for FP32 and FP16 models, and BF16 precision
  32. for BF16 models.
  33. seed: Random seed for reproducibility.
  34. revision: The specific model version to use. It can be a branch name,
  35. a tag name, or a commit id. If unspecified, will use the default
  36. version.
  37. code_revision: The specific revision to use for the model code on
  38. Hugging Face Hub. It can be a branch name, a tag name, or a
  39. commit id. If unspecified, will use the default version.
  40. tokenizer_revision: The specific tokenizer version to use. It can be a
  41. branch name, a tag name, or a commit id. If unspecified, will use
  42. the default version.
  43. max_model_len: Maximum length of a sequence (including prompt and
  44. output). If None, will be derived from the model.
  45. quantization: Quantization method that was used to quantize the model
  46. weights. If None, we assume the model weights are not quantized.
  47. load_in_4bit: Whether to load the FP16 model in bitsandbytes 4bit
  48. format. Works with AWQ models as well as FP16.
  49. load_in_8bit: Whether to load the FP16 model in 8bit format. Slower
  50. than load_in_smooth in terms of throughput.
  51. load_in_smooth: Whether to load the FP16 model in smoothquant format.
  52. quantization_param_path: Path to JSON file containing scaling factors.
  53. Used to load KV cache scaling factors into the model when KV cache
  54. type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
  55. be used to load activation and weight scaling factors when the
  56. model dtype is FP8_E4M3 on ROCm.
  57. enforce_eager: Whether to enforce eager execution. If True, we will
  58. disable CUDA graph and always execute the model in eager mode.
  59. If False, we will use CUDA graph and eager execution in hybrid.
  60. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
  61. When a sequence has context length larger than this, we fall back
  62. to eager mode.
  63. skip_tokenizer_init: If true, skip initialization of tokenizer and
  64. detokenizer.
  65. """
  66. def __init__(
  67. self,
  68. model: str,
  69. tokenizer: str,
  70. tokenizer_mode: str,
  71. trust_remote_code: bool,
  72. dtype: Union[str, torch.dtype],
  73. seed: int,
  74. revision: Optional[str] = None,
  75. code_revision: Optional[str] = None,
  76. tokenizer_revision: Optional[str] = None,
  77. max_model_len: Optional[int] = None,
  78. quantization: Optional[str] = None,
  79. load_in_4bit: bool = False,
  80. load_in_8bit: bool = False,
  81. load_in_smooth: bool = False,
  82. quantization_param_path: Optional[str] = None,
  83. enforce_eager: bool = False,
  84. max_context_len_to_capture: Optional[int] = None,
  85. max_logprobs: int = 5,
  86. skip_tokenizer_init: bool = False,
  87. ) -> None:
  88. self.model = model
  89. self.tokenizer = tokenizer
  90. self.tokenizer_mode = tokenizer_mode
  91. self.trust_remote_code = trust_remote_code
  92. self.seed = seed
  93. self.revision = revision
  94. self.code_revision = code_revision
  95. self.tokenizer_revision = tokenizer_revision
  96. self.quantization = quantization
  97. self.load_in_4bit = load_in_4bit
  98. self.load_in_8bit = load_in_8bit
  99. self.load_in_smooth = load_in_smooth
  100. self.quantization_param_path = quantization_param_path
  101. self.enforce_eager = enforce_eager
  102. self.max_context_len_to_capture = max_context_len_to_capture
  103. self.max_logprobs = max_logprobs
  104. self.skip_tokenizer_init = skip_tokenizer_init
  105. self.hf_config = get_config(self.model, trust_remote_code, revision,
  106. code_revision)
  107. self.hf_text_config = get_hf_text_config(self.hf_config)
  108. self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
  109. self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
  110. max_model_len)
  111. if not self.skip_tokenizer_init:
  112. self._verify_tokenizer_mode()
  113. self._verify_quantization()
  114. self._verify_cuda_graph()
  115. def _verify_tokenizer_mode(self) -> None:
  116. tokenizer_mode = self.tokenizer_mode.lower()
  117. if tokenizer_mode not in ["auto", "slow"]:
  118. raise ValueError(
  119. f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
  120. "either 'auto' or 'slow'.")
  121. self.tokenizer_mode = tokenizer_mode
  122. def _verify_quantization(self) -> None:
  123. supported_quantization = [*QUANTIZATION_METHODS]
  124. rocm_supported_quantization = ["gptq", "squeezellm"]
  125. if self.quantization is not None:
  126. self.quantization = self.quantization.lower()
  127. # Parse quantization method from the HF model config, if available.
  128. quant_cfg = getattr(self.hf_config, "quantization_config", None)
  129. if quant_cfg is not None:
  130. quant_method = quant_cfg.get("quant_method", "").lower()
  131. # compat: autogptq >=0.8.0 use checkpoint_format: str
  132. # compat: autogptq <=0.7.1 is_marlin_format: bool
  133. is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
  134. or quant_cfg.get("is_marlin_format", False))
  135. # Use marlin if the GPTQ model is serialized in marlin format.
  136. if quant_method == "gptq" and is_format_marlin:
  137. logger.info("The model is serialized in Marlin format. "
  138. "Using Marlin kernel.")
  139. quant_method = "marlin"
  140. if self.quantization == "gptq":
  141. self.quantization = quant_method
  142. if self.quantization is None:
  143. self.quantization = quant_method
  144. elif self.quantization != quant_method:
  145. raise ValueError(
  146. "Quantization method specified in the model config "
  147. f"({quant_method}) does not match the quantization "
  148. f"method specified in the `quantization` argument "
  149. f"({self.quantization}).")
  150. if self.load_in_4bit:
  151. # the kernels seem to not work with 4bit weight_only
  152. if torch.cuda.get_device_capability(0)[0] < 8:
  153. raise ValueError(
  154. "load_in_4bit quantization is not supported on GPUs with "
  155. "compute capability less than 8.0.")
  156. if self.quantization is None:
  157. self.quantization = "bnb"
  158. self.hf_config.quantization_config = {
  159. "bits": 4,
  160. "quant_mode": "weight_only",
  161. "quant_method": "bnb",
  162. "group_size": 128,
  163. "zero_point": True,
  164. "from_float": True
  165. }
  166. elif self.quantization == "awq":
  167. logger.warning("AWQ model is being loaded in 4bit bnb format.")
  168. self.quantization = "bnb"
  169. self.hf_config.quantization_config = {
  170. "zero_point": True,
  171. "q_group_size": 128,
  172. "w_bit": 4,
  173. "version": "gemm"
  174. }
  175. elif self.quantization != "bnb":
  176. raise ValueError("4bit quantization is not supported in "
  177. f"{self.quantization}.")
  178. if self.load_in_8bit:
  179. if self.quantization is None:
  180. self.quantization = "bnb"
  181. elif self.quantization != "bnb":
  182. raise ValueError("8bit quantization is not supported in "
  183. f"{self.quantization}.")
  184. self.hf_config.quantization_config = {
  185. "bits": 8,
  186. "quant_mode": "llm_int8",
  187. "quant_method": "bnb",
  188. "group_size": 128,
  189. "zero_point": True,
  190. "from_float": True
  191. }
  192. self.enforce_eager = True
  193. if self.load_in_smooth:
  194. if self.quantization is None:
  195. self.quantization = "bnb"
  196. elif self.quantization != "bnb":
  197. raise ValueError("Smooth quantization is not supported in "
  198. f"{self.quantization}.")
  199. self.hf_config.quantization_config = {
  200. "bits": 8,
  201. "quant_mode": "smoothquant",
  202. "quant_method": "bnb",
  203. "group_size": 128,
  204. "zero_point": True,
  205. "from_float": True
  206. }
  207. self.enforce_eager = True
  208. if self.quantization is not None:
  209. if self.quantization not in supported_quantization:
  210. raise ValueError(
  211. f"Unknown quantization method: {self.quantization}. Must "
  212. f"be one of {supported_quantization}.")
  213. if is_hip(
  214. ) and self.quantization not in rocm_supported_quantization:
  215. raise ValueError(
  216. f"{self.quantization} quantization is currently not "
  217. "supported in ROCm.")
  218. if self.quantization != "marlin":
  219. logger.warning(
  220. f"{self.quantization} quantization is not fully "
  221. "optimized yet. The speed can be slower than "
  222. "non-quantized models.")
  223. def _verify_cuda_graph(self) -> None:
  224. if self.max_context_len_to_capture is None:
  225. self.max_context_len_to_capture = self.max_model_len
  226. self.max_context_len_to_capture = min(self.max_context_len_to_capture,
  227. self.max_model_len)
  228. def verify_with_parallel_config(
  229. self,
  230. parallel_config: "ParallelConfig",
  231. ) -> None:
  232. total_num_attention_heads = self.hf_text_config.num_attention_heads
  233. tensor_parallel_size = parallel_config.tensor_parallel_size
  234. if total_num_attention_heads % tensor_parallel_size != 0:
  235. raise ValueError(
  236. f"Total number of attention heads ({total_num_attention_heads})"
  237. " must be divisible by tensor parallel size "
  238. f"({tensor_parallel_size}).")
  239. total_num_hidden_layers = self.hf_text_config.num_hidden_layers
  240. pipeline_parallel_size = parallel_config.pipeline_parallel_size
  241. if total_num_hidden_layers % pipeline_parallel_size != 0:
  242. raise ValueError(
  243. f"Total number of hidden layers ({total_num_hidden_layers}) "
  244. "must be divisible by pipeline parallel size "
  245. f"({pipeline_parallel_size}).")
  246. def get_sliding_window(self) -> Optional[int]:
  247. """Get the sliding window size, or None if disabled.
  248. """
  249. # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
  250. # addition to sliding window size. We check if that field is present
  251. # and if it's False, return None.
  252. if (hasattr(self.hf_text_config, "use_sliding_window")
  253. and not self.hf_text_config.use_sliding_window):
  254. return None
  255. return getattr(self.hf_text_config, "sliding_window", None)
  256. def get_vocab_size(self) -> int:
  257. return self.hf_text_config.vocab_size
  258. def get_hidden_size(self) -> int:
  259. return self.hf_text_config.hidden_size
  260. def get_head_size(self) -> int:
  261. if hasattr(self.hf_text_config, "head_dim"):
  262. return self.hf_text_config.head_dim
  263. # FIXME: This may not be true for all models.
  264. return (self.hf_text_config.hidden_size //
  265. self.hf_text_config.num_attention_heads)
  266. def get_total_num_kv_heads(self) -> int:
  267. """Returns the total number of KV heads."""
  268. # For GPTBigCode & Falcon:
  269. # NOTE: for falcon, when new_decoder_architecture is True, the
  270. # multi_query flag is ignored and we use n_head_kv for the number of
  271. # KV heads.
  272. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
  273. new_decoder_arch_falcon = (
  274. self.hf_config.model_type in falcon_model_types
  275. and getattr(self.hf_config, "new_decoder_architecture", False))
  276. if not new_decoder_arch_falcon and getattr(self.hf_text_config,
  277. "multi_query", False):
  278. # Multi-query attention, only one KV head.
  279. # Currently, tensor parallelism is not supported in this case.
  280. return 1
  281. # For DBRX and MPT
  282. if self.hf_config.model_type in ["dbrx", "mpt"]:
  283. return getattr(self.hf_config.attn_config, "kv_n_heads",
  284. self.hf_config.num_attention_heads)
  285. attributes = [
  286. # For Falcon:
  287. "n_head_kv",
  288. "num_kv_heads",
  289. # For LLaMA-2:
  290. "num_key_value_heads",
  291. # For ChatGLM:
  292. "multi_query_group_num",
  293. ]
  294. for attr in attributes:
  295. num_kv_heads = getattr(self.hf_text_config, attr, None)
  296. if num_kv_heads is not None:
  297. return num_kv_heads
  298. # For non-grouped-query attention models, the number of KV heads is
  299. # equal to the number of attention heads.
  300. return self.hf_text_config.num_attention_heads
  301. def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
  302. """Returns the number of KV heads per GPU."""
  303. total_num_kv_heads = self.get_total_num_kv_heads()
  304. # If tensor parallelism is used, we divide the number of KV heads by
  305. # the tensor parallel size. We will replicate the KV heads in the
  306. # case where the number of KV heads is smaller than the tensor
  307. # parallel size so each GPU has at least one KV head.
  308. return max(1,
  309. total_num_kv_heads // parallel_config.tensor_parallel_size)
  310. def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  311. total_num_hidden_layers = self.hf_text_config.num_hidden_layers
  312. return total_num_hidden_layers // parallel_config.pipeline_parallel_size
  313. class CacheConfig:
  314. """Configuration for the KV cache.
  315. Args:
  316. block_size: Size of a cache block in number of tokens.
  317. gpu_memory_utilization: Fraction of GPU memory to use for the
  318. Aphrodite execution.
  319. swap_space: Size of the CPU swap space per GPU (in GiB).
  320. cache_dtype: Data type for kv cache storage.
  321. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
  322. profiled num_gpu_blocks if specified. Does nothing if None.
  323. """
  324. def __init__(
  325. self,
  326. block_size: int,
  327. gpu_memory_utilization: float,
  328. swap_space: int,
  329. cache_dtype: str,
  330. num_gpu_blocks_override: Optional[int] = None,
  331. sliding_window: Optional[int] = None,
  332. enable_prefix_caching: bool = False,
  333. ) -> None:
  334. self.block_size = block_size
  335. self.gpu_memory_utilization = gpu_memory_utilization
  336. self.swap_space_bytes = swap_space * _GB
  337. self.num_gpu_blocks_override = num_gpu_blocks_override
  338. self.cache_dtype = cache_dtype
  339. self.sliding_window = sliding_window
  340. self.enable_prefix_caching = enable_prefix_caching
  341. self._verify_args()
  342. self._verify_cache_dtype()
  343. # Will be set after profiling.
  344. self.num_gpu_blocks = None
  345. self.num_cpu_blocks = None
  346. def metrics_info(self):
  347. # convert cache_config to dict(key: str, value: str) for prometheus
  348. # metrics info
  349. return {key: str(value) for key, value in self.__dict__.items()}
  350. def _verify_args(self) -> None:
  351. if self.gpu_memory_utilization > 1.0:
  352. raise ValueError(
  353. "GPU memory utilization must be less than 1.0. Got "
  354. f"{self.gpu_memory_utilization}.")
  355. def _verify_cache_dtype(self) -> None:
  356. if self.cache_dtype == "auto":
  357. pass
  358. elif self.cache_dtype == "fp8":
  359. if not is_hip():
  360. nvcc_cuda_version = get_nvcc_cuda_version()
  361. if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
  362. raise ValueError(
  363. "FP8 is not supported when cuda version is"
  364. "lower than 11.8.")
  365. logger.info(
  366. "Using fp8 data type to store kv cache. It reduces the GPU "
  367. "memory footprint and boosts the performance. "
  368. "But it may cause slight accuracy drop without scaling "
  369. "factors. FP8_E5M2 (without scaling) is only supported on "
  370. "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
  371. "is instead supported for common inference criteria.")
  372. else:
  373. raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
  374. def verify_with_parallel_config(
  375. self,
  376. parallel_config: "ParallelConfig",
  377. ) -> None:
  378. total_cpu_memory = get_cpu_memory()
  379. # FIXME: Here, it is assumed that the GPUs in a tensor parallel
  380. # group are in the same node. However, the GPUs may span multiple nodes.
  381. num_gpus_per_node = parallel_config.tensor_parallel_size
  382. cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
  383. msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
  384. f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
  385. "allocated for the swap space.")
  386. if cpu_memory_usage > 0.7 * total_cpu_memory:
  387. raise ValueError("Too large swap space. " + msg)
  388. elif cpu_memory_usage > 0.4 * total_cpu_memory:
  389. logger.warning("Possibly too large swap space. " + msg)
  390. @dataclass
  391. class TokenizerPoolConfig:
  392. """Configuration for the tokenizer pool.
  393. Args:
  394. pool_size: Number of tokenizer workers in the pool.
  395. pool_type: Type of the pool.
  396. extra_config: Additional config for the pool.
  397. The way the config will be used depends on the
  398. pool type.
  399. """
  400. pool_size: int
  401. pool_type: str
  402. extra_config: dict
  403. def __post_init__(self):
  404. if self.pool_type not in ("ray", ):
  405. raise ValueError(f"Unknown pool type: {self.pool_type}")
  406. if not isinstance(self.extra_config, dict):
  407. raise ValueError("extra_config must be a dictionary.")
  408. @classmethod
  409. def create_config(
  410. cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
  411. tokenizer_pool_extra_config: Optional[Union[str, dict]]
  412. ) -> Optional["TokenizerPoolConfig"]:
  413. """Create a TokenizerPoolConfig from the given parameters.
  414. If tokenizer_pool_size is 0, return None.
  415. Args:
  416. tokenizer_pool_size: Number of tokenizer workers in the pool.
  417. tokenizer_pool_type: Type of the pool.
  418. tokenizer_pool_extra_config: Additional config for the pool.
  419. The way the config will be used depends on the
  420. pool type. This can be a JSON string (will be parsed).
  421. """
  422. if tokenizer_pool_size:
  423. if isinstance(tokenizer_pool_extra_config, str):
  424. tokenizer_pool_extra_config_parsed = json.loads(
  425. tokenizer_pool_extra_config)
  426. else:
  427. tokenizer_pool_extra_config_parsed = (
  428. tokenizer_pool_extra_config or {})
  429. tokenizer_pool_config = cls(tokenizer_pool_size,
  430. tokenizer_pool_type,
  431. tokenizer_pool_extra_config_parsed)
  432. else:
  433. tokenizer_pool_config = None
  434. return tokenizer_pool_config
  435. class LoadFormat(str, enum.Enum):
  436. AUTO = "auto"
  437. PT = "pt"
  438. SAFETENSORS = "safetensors"
  439. NPCACHE = "npcache"
  440. DUMMY = "dummy"
  441. TENSORIZER = "tensorizer"
  442. @dataclass
  443. class LoadConfig:
  444. """
  445. download_dir: Directory to download and load the weights, default to the
  446. default cache directory of huggingface.
  447. load_format: The format of the model weights to load:
  448. "auto" will try to load the weights in the safetensors format and
  449. fall back to the pytorch bin format if safetensors format is
  450. not available.
  451. "pt" will load the weights in the pytorch bin format.
  452. "safetensors" will load the weights in the safetensors format.
  453. "npcache" will load the weights in pytorch format and store
  454. a numpy cache to speed up the loading.
  455. "dummy" will initialize the weights with random values, which is
  456. mainly for profiling.
  457. "tensorizer" will use CoreWeave's tensorizer library for
  458. fast weight loading.
  459. """
  460. load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
  461. download_dir: Optional[str] = None
  462. model_loader_extra_config: Optional[Union[str, dict]] = field(
  463. default_factory=dict)
  464. def __post_init__(self):
  465. model_loader_extra_config = self.model_loader_extra_config or {}
  466. if isinstance(model_loader_extra_config, str):
  467. self.model_loader_extra_config = json.loads(
  468. model_loader_extra_config)
  469. self._verify_load_format()
  470. def _verify_load_format(self) -> None:
  471. if not isinstance(self.load_format, str):
  472. return
  473. load_format = self.load_format.lower()
  474. self.load_format = LoadFormat(load_format)
  475. rocm_not_supported_load_format: List[str] = []
  476. if is_hip() and load_format in rocm_not_supported_load_format:
  477. rocm_supported_load_format = [
  478. f for f in LoadFormat.__members__
  479. if (f not in rocm_not_supported_load_format)
  480. ]
  481. raise ValueError(
  482. f"load format '{load_format}' is not supported in ROCm. "
  483. f"Supported load formats are "
  484. f"{rocm_supported_load_format}")
  485. class ParallelConfig:
  486. """Configuration for the distributed execution.
  487. Args:
  488. pipeline_parallel_size: Number of pipeline parallel groups.
  489. tensor_parallel_size: Number of tensor parallel groups.
  490. worker_use_ray: Whether to use Ray for model workers. Will be set to
  491. True if either pipeline_parallel_size or tensor_parallel_size is
  492. greater than 1.
  493. max_parallel_loading_workers: Maximum number of multiple batches
  494. when load model sequentially. To avoid RAM OOM when using tensor
  495. parallel and large models.
  496. disable_custom_all_reduce: Disable the custom all-reduce kernel and
  497. fall back to NCCL.
  498. tokenizer_pool_config: Config for the tokenizer pool.
  499. If None, will use synchronous tokenization.
  500. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
  501. https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
  502. """
  503. def __init__(
  504. self,
  505. pipeline_parallel_size: int,
  506. tensor_parallel_size: int,
  507. worker_use_ray: bool,
  508. max_parallel_loading_workers: Optional[int] = None,
  509. disable_custom_all_reduce: bool = False,
  510. tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
  511. ray_workers_use_nsight: bool = False,
  512. placement_group: Optional["PlacementGroup"] = None,
  513. ) -> None:
  514. self.pipeline_parallel_size = pipeline_parallel_size
  515. self.tensor_parallel_size = tensor_parallel_size
  516. self.worker_use_ray = worker_use_ray
  517. self.max_parallel_loading_workers = max_parallel_loading_workers
  518. self.disable_custom_all_reduce = disable_custom_all_reduce
  519. self.tokenizer_pool_config = tokenizer_pool_config
  520. self.ray_workers_use_nsight = ray_workers_use_nsight
  521. self.placement_group = placement_group
  522. self.world_size = pipeline_parallel_size * self.tensor_parallel_size
  523. if self.world_size > 1:
  524. self.worker_use_ray = True
  525. self._verify_args()
  526. def _verify_args(self) -> None:
  527. if self.pipeline_parallel_size > 1:
  528. raise NotImplementedError(
  529. "Pipeline parallelism is not supported yet.")
  530. if not self.disable_custom_all_reduce and self.world_size > 1:
  531. if is_hip():
  532. self.disable_custom_all_reduce = True
  533. logger.info(
  534. "Disabled the custom all-reduce kernel because it is not "
  535. "supported on AMD GPUs.")
  536. elif self.pipeline_parallel_size > 1:
  537. self.disable_custom_all_reduce = True
  538. logger.info(
  539. "Disabled the custom all-reduce kernel because it is not "
  540. "supported with pipeline parallelism.")
  541. if self.ray_workers_use_nsight and not self.worker_use_ray:
  542. raise ValueError("Unable to use nsight profiling unless workers "
  543. "run with Ray.")
  544. class SchedulerConfig:
  545. """Scheduler configuration.
  546. Args:
  547. max_num_batched_tokens: Maximum number of tokens to be processed in
  548. a single iteration.
  549. max_num_seqs: Maximum number of sequences to be processed in a single
  550. iteration.
  551. max_model_len: Maximum length of a sequence (including prompt
  552. and generated text).
  553. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
  554. num_lookahead_slots: The number of slots to allocate per sequence per
  555. step, beyond the known token ids. This is used in speculative
  556. decoding to store KV activations of tokens which may or may not be
  557. accepted.
  558. delay_factor: Apply a delay (of delay factor multiplied by previous
  559. prompt latency) before scheduling next prompt.
  560. enable_chunked_prefill: If True, prefill requests can be chunked based
  561. on the remaining max_num_batched_tokens.
  562. """
  563. def __init__(
  564. self,
  565. max_num_batched_tokens: Optional[int],
  566. max_num_seqs: int,
  567. max_model_len: int,
  568. use_v2_block_manager: bool = False,
  569. num_lookahead_slots: int = 0,
  570. delay_factor: float = 0.0,
  571. enable_chunked_prefill: bool = False,
  572. ) -> None:
  573. if max_num_batched_tokens is not None:
  574. self.max_num_batched_tokens = max_num_batched_tokens
  575. else:
  576. if enable_chunked_prefill:
  577. # For chunked prefill, choose the well-tuned batch size.
  578. self.max_num_batched_tokens = 768
  579. else:
  580. # If max_model_len is too short, use 2048 as the default value
  581. # for higher throughput.
  582. self.max_num_batched_tokens = max(max_model_len, 2048)
  583. if enable_chunked_prefill:
  584. logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
  585. self.max_num_seqs = max_num_seqs
  586. self.max_model_len = max_model_len
  587. self.use_v2_block_manager = use_v2_block_manager
  588. self.num_lookahead_slots = num_lookahead_slots
  589. self.delay_factor = delay_factor
  590. self.chunked_prefill_enabled = enable_chunked_prefill
  591. self._verify_args()
  592. def _verify_args(self) -> None:
  593. if (self.max_num_batched_tokens < self.max_model_len
  594. and not self.chunked_prefill_enabled):
  595. raise ValueError(
  596. f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
  597. f"smaller than max_model_len ({self.max_model_len}). "
  598. "This effectively limits the maximum sequence length to "
  599. "max_num_batched_tokens and makes Aphrodite reject longer "
  600. "sequences. Please increase max_num_batched_tokens or "
  601. "decrease max_model_len.")
  602. if self.max_num_batched_tokens < self.max_num_seqs:
  603. raise ValueError(
  604. f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
  605. "be greater than or equal to max_num_seqs "
  606. f"({self.max_num_seqs}).")
  607. if self.num_lookahead_slots < 0:
  608. raise ValueError(
  609. "num_lookahead_slots "
  610. f"({self.num_lookahead_slots}) must be greater than or "
  611. "equal to 0.")
  612. class DeviceConfig:
  613. def __init__(self, device: str = "auto") -> None:
  614. if device == "auto":
  615. # Automated device type detection
  616. if is_neuron():
  617. self.device_type = "neuron"
  618. elif is_cpu():
  619. self.device_type = "cpu"
  620. else:
  621. # We don't call torch.cuda.is_available() here to
  622. # avoid initializing CUDA before workers are forked
  623. self.device_type = "cuda"
  624. else:
  625. # Device type is assigned explicitly
  626. self.device_type = device
  627. # Some device types require processing inputs on CPU
  628. if self.device_type in ["neuron"]:
  629. self.device = torch.device("cpu")
  630. else:
  631. # Set device with device type
  632. self.device = torch.device(self.device_type)
  633. class SpeculativeConfig:
  634. """Configuration for speculative decoding.
  635. The configuration is currently specialized to draft-model speculative
  636. decoding with top-1 proposals.
  637. """
  638. @staticmethod
  639. def maybe_create_spec_config(
  640. target_model_config: ModelConfig,
  641. target_parallel_config: ParallelConfig,
  642. target_dtype: str,
  643. speculative_model: Optional[str],
  644. num_speculative_tokens: Optional[int],
  645. speculative_max_model_len: Optional[int],
  646. enable_chunked_prefill: bool,
  647. use_v2_block_manager: bool,
  648. ) -> Optional["SpeculativeConfig"]:
  649. """Create a SpeculativeConfig if possible, else return None.
  650. This function attempts to create a SpeculativeConfig object based on the
  651. provided parameters. If the necessary conditions are met, it returns an
  652. instance of SpeculativeConfig. Otherwise, it returns None.
  653. Args:
  654. target_model_config (ModelConfig): The configuration of the target
  655. model.
  656. target_parallel_config (ParallelConfig): The parallel configuration
  657. for the target model.
  658. target_dtype (str): The data type used for the target model.
  659. speculative_model (Optional[str]): The name of the speculative
  660. model, if provided.
  661. num_speculative_tokens (Optional[int]): The number of speculative
  662. tokens, if provided.
  663. speculative_max_model_len (Optional[int]): The maximum model len of
  664. the speculative model. Used when testing the ability to skip
  665. speculation for some sequences.
  666. enable_chunked_prefill (bool): Whether Aphrodite is configured to
  667. use chunked prefill or not. Used for raising an error since its
  668. not yet compatible with spec decode.
  669. use_v2_block_manager (bool): Whether Aphrodite is configured to
  670. use the v2 block manager or not. Used for raising an error
  671. since the v2 block manager is required with spec decode.
  672. ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
  673. window, if provided.
  674. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
  675. window, if provided.
  676. Returns:
  677. Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
  678. the necessary conditions are met, else None.
  679. """
  680. if (speculative_model is None and num_speculative_tokens is None):
  681. return None
  682. if speculative_model is not None and num_speculative_tokens is None:
  683. raise ValueError(
  684. "Expected both speculative_model and "
  685. "num_speculative_tokens to be provided, but found "
  686. f"{speculative_model=} and {num_speculative_tokens=}.")
  687. assert (speculative_model is not None
  688. and num_speculative_tokens is not None)
  689. if enable_chunked_prefill:
  690. raise ValueError(
  691. "Speculative decoding and chunked prefill are "
  692. f"currently mutually exclusive ({enable_chunked_prefill=}).")
  693. if not use_v2_block_manager:
  694. raise ValueError(
  695. "Speculative decoding requires usage of the V2 "
  696. "block manager. Enable it with --use-v2-block-manager.")
  697. # TODO: The user should be able to specify revision/quantization/max
  698. # model len for the draft model. It is not currently supported.
  699. draft_revision = None
  700. draft_code_revision = None
  701. draft_quantization = None
  702. draft_model_config = ModelConfig(
  703. model=speculative_model,
  704. tokenizer=target_model_config.tokenizer,
  705. tokenizer_mode=target_model_config.tokenizer_mode,
  706. trust_remote_code=target_model_config.trust_remote_code,
  707. dtype=target_model_config.dtype,
  708. seed=target_model_config.seed,
  709. revision=draft_revision,
  710. code_revision=draft_code_revision,
  711. tokenizer_revision=target_model_config.tokenizer_revision,
  712. max_model_len=None,
  713. quantization=draft_quantization,
  714. enforce_eager=target_model_config.enforce_eager,
  715. max_context_len_to_capture=target_model_config.
  716. max_context_len_to_capture,
  717. max_logprobs=target_model_config.max_logprobs,
  718. )
  719. draft_model_config.max_model_len = (
  720. SpeculativeConfig._maybe_override_draft_max_model_len(
  721. speculative_max_model_len,
  722. draft_model_config.max_model_len,
  723. target_model_config.max_model_len,
  724. ))
  725. draft_parallel_config = (
  726. SpeculativeConfig.create_draft_parallel_config(
  727. target_parallel_config))
  728. return SpeculativeConfig(
  729. draft_model_config,
  730. draft_parallel_config,
  731. num_speculative_tokens,
  732. )
  733. @staticmethod
  734. def _maybe_override_draft_max_model_len(
  735. speculative_max_model_len: Optional[int],
  736. draft_max_model_len: int,
  737. target_max_model_len: int,
  738. ) -> int:
  739. """Determine the max sequence len for the draft model. This is usually
  740. the draft_max_model_len, but may be the target_max_model_len if it is
  741. less than the draft_max_model_len, or may be speculative_max_model_len
  742. if it is specified.
  743. This is necessary so that sequences do not exceed the capacity of the
  744. draft model or the target model.
  745. speculative_max_model_len is mainly used for testing that sequences can
  746. skip speculation.
  747. """
  748. if speculative_max_model_len is not None:
  749. if speculative_max_model_len > draft_max_model_len:
  750. raise ValueError(f"{speculative_max_model_len=} cannot be "
  751. f"larger than {draft_max_model_len=}")
  752. if speculative_max_model_len > target_max_model_len:
  753. raise ValueError(f"{speculative_max_model_len=} cannot be "
  754. f"larger than {target_max_model_len=}")
  755. return speculative_max_model_len
  756. return min(
  757. draft_max_model_len,
  758. target_max_model_len,
  759. )
  760. @staticmethod
  761. def create_draft_parallel_config(
  762. target_parallel_config: ParallelConfig) -> ParallelConfig:
  763. """Create a parallel config for use by the draft worker.
  764. This is mostly a copy of the target parallel config. In the future the
  765. draft worker can have a different parallel strategy, e.g. TP=1.
  766. """
  767. draft_parallel_config = ParallelConfig(
  768. pipeline_parallel_size=target_parallel_config.
  769. pipeline_parallel_size,
  770. tensor_parallel_size=target_parallel_config.tensor_parallel_size,
  771. worker_use_ray=target_parallel_config.worker_use_ray,
  772. max_parallel_loading_workers=target_parallel_config.
  773. max_parallel_loading_workers,
  774. disable_custom_all_reduce=target_parallel_config.
  775. disable_custom_all_reduce,
  776. tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
  777. ray_workers_use_nsight=target_parallel_config.
  778. ray_workers_use_nsight,
  779. placement_group=target_parallel_config.placement_group,
  780. )
  781. return draft_parallel_config
  782. def __init__(
  783. self,
  784. draft_model_config: ModelConfig,
  785. draft_parallel_config: ParallelConfig,
  786. num_speculative_tokens: int,
  787. ):
  788. """Create a SpeculativeConfig object.
  789. Args:
  790. draft_model_config: ModelConfig for the draft model.
  791. draft_parallel_config: ParallelConfig for the draft model.
  792. num_speculative_tokens: The number of tokens to sample from the
  793. draft model before scoring with the target model.
  794. """
  795. self.draft_model_config = draft_model_config
  796. self.draft_parallel_config = draft_parallel_config
  797. self.num_speculative_tokens = num_speculative_tokens
  798. self._verify_args()
  799. def _verify_args(self) -> None:
  800. if self.num_speculative_tokens <= 0:
  801. raise ValueError("Expected num_speculative_tokens to be greater "
  802. f"than zero ({self.num_speculative_tokens}).")
  803. if self.draft_model_config:
  804. self.draft_model_config.verify_with_parallel_config(
  805. self.draft_parallel_config)
  806. @property
  807. def num_lookahead_slots(self) -> int:
  808. """The number of additional slots the scheduler should allocate per
  809. step, in addition to the slots allocated for each known token.
  810. This is equal to the number of speculative tokens, as each speculative
  811. token must be scored.
  812. """
  813. return self.num_speculative_tokens
  814. def __repr__(self) -> str:
  815. draft_model = self.draft_model_config.model
  816. num_spec_tokens = self.num_speculative_tokens
  817. return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
  818. @dataclass
  819. class LoRAConfig:
  820. max_lora_rank: int
  821. max_loras: int
  822. max_cpu_loras: Optional[int] = None
  823. lora_dtype: Optional[torch.dtype] = None
  824. lora_extra_vocab_size: int = 256
  825. # This is a constant.
  826. lora_vocab_padding_size: ClassVar[int] = 256
  827. def __post_init__(self):
  828. # Keep this in sync with kernels/punica/bgmv/bgmv_config.h
  829. possible_max_ranks = (8, 16, 32, 64)
  830. possible_lora_extra_vocab_size = (0, 256, 512)
  831. if self.max_lora_rank not in possible_max_ranks:
  832. raise ValueError(
  833. f"max_lora_rank ({self.max_lora_rank}) must be one of "
  834. f"{possible_max_ranks}.")
  835. if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
  836. raise ValueError(
  837. f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
  838. f"must be one of {possible_lora_extra_vocab_size}.")
  839. if self.max_loras < 1:
  840. raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
  841. if self.max_cpu_loras is None:
  842. self.max_cpu_loras = self.max_loras
  843. elif self.max_cpu_loras < self.max_loras:
  844. raise ValueError(
  845. f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
  846. f"max_loras ({self.max_loras})")
  847. def verify_with_model_config(self, model_config: ModelConfig):
  848. if self.lora_dtype in (None, "auto"):
  849. self.lora_dtype = model_config.dtype
  850. elif isinstance(self.lora_dtype, str):
  851. self.lora_dtype = getattr(torch, self.lora_dtype)
  852. if model_config.quantization and model_config.quantization not in [
  853. "awq", "gptq"
  854. ]:
  855. # TODO support all other quants
  856. logger.warning(f"{model_config.quantization} quantization is not "
  857. "tested with LoRA yet.")
  858. def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
  859. if scheduler_config.max_num_batched_tokens > 65528:
  860. raise ValueError(
  861. "Due to limitations of the custom LoRA CUDA kernel, "
  862. "max_num_batched_tokens must be <= 65528 when "
  863. "LoRA is enabled.")
  864. @dataclass
  865. class VisionLanguageConfig:
  866. """Configs the input data format and how models should run for
  867. vision language models."""
  868. class ImageInputType(enum.Enum):
  869. """Image input type into the vision language model.
  870. An image roughly goes through the following transformation:
  871. Raw image --> pixel values --> image features --> image embeddings.
  872. The difference between different image input types is where the
  873. image encoder (pixel values --> image features) is run.
  874. Different image input types also correspond to different tensor shapes.
  875. For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
  876. IMAGE_FEATURES: (1, 576, 1024).
  877. """
  878. PIXEL_VALUES = enum.auto()
  879. IMAGE_FEATURES = enum.auto()
  880. image_input_type: ImageInputType
  881. # The input id corresponding to image token.
  882. image_token_id: int
  883. # Used for running `run_prefill_max_token`.
  884. # For models that support varying resolution, this corresponds to
  885. # worst case scenario (biggest supported resolution).
  886. image_input_shape: tuple
  887. image_feature_size: int
  888. @classmethod
  889. def get_image_input_enum_type(
  890. cls, value: str) -> "VisionLanguageConfig.ImageInputType":
  891. """Get the image input type from a string."""
  892. try:
  893. return cls.ImageInputType[value.upper()]
  894. except KeyError as e:
  895. raise ValueError(f"{value} is not a valid choice. "
  896. f"Expecting to choose from "
  897. f"{[x.name for x in cls.ImageInputType]}.") from e
  898. _STR_DTYPE_TO_TORCH_DTYPE = {
  899. "half": torch.float16,
  900. "float16": torch.float16,
  901. "float": torch.float32,
  902. "float32": torch.float32,
  903. "bfloat16": torch.bfloat16,
  904. }
  905. _ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
  906. def _get_and_verify_dtype(
  907. config: PretrainedConfig,
  908. dtype: Union[str, torch.dtype],
  909. ) -> torch.dtype:
  910. # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
  911. # because config.torch_dtype can be None.
  912. config_dtype = getattr(config, "torch_dtype", None)
  913. if config_dtype is None:
  914. config_dtype = torch.float32
  915. if isinstance(dtype, str):
  916. dtype = dtype.lower()
  917. if dtype == "auto":
  918. if config_dtype == torch.float32:
  919. # Following the common practice, we use float16 for float32
  920. # models.
  921. torch_dtype = torch.float16
  922. else:
  923. torch_dtype = config_dtype
  924. else:
  925. if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
  926. raise ValueError(f"Unknown dtype: {dtype}")
  927. torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
  928. elif isinstance(dtype, torch.dtype):
  929. torch_dtype = dtype
  930. else:
  931. raise ValueError(f"Unknown dtype: {dtype}")
  932. if is_hip() and torch_dtype == torch.float32:
  933. rocm_supported_dtypes = [
  934. k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
  935. if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
  936. ]
  937. raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
  938. f"Supported dtypes are {rocm_supported_dtypes}")
  939. # Verify the dtype.
  940. if torch_dtype != config_dtype:
  941. if torch_dtype == torch.float32:
  942. # Upcasting to float32 is allowed.
  943. pass
  944. elif config_dtype == torch.float32:
  945. # Downcasting from float32 to float16 or bfloat16 is allowed.
  946. pass
  947. else:
  948. # Casting between float16 and bfloat16 is allowed with a warning.
  949. logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
  950. return torch_dtype
  951. def _get_and_verify_max_len(
  952. hf_config: PretrainedConfig,
  953. max_model_len: Optional[int],
  954. ) -> int:
  955. """Get and verify the model's maximum length."""
  956. derived_max_model_len = float("inf")
  957. possible_keys = [
  958. # Cohere: needs to prioritize this over "max_position_embeddings"
  959. "model_max_length",
  960. # OPT
  961. "max_position_embeddings",
  962. # GPT-2
  963. "n_positions",
  964. # MPT
  965. "max_seq_len",
  966. # ChatGLM2
  967. "seq_length",
  968. # Command-R
  969. "model_max_length",
  970. # Others
  971. "max_sequence_length",
  972. "max_seq_length",
  973. "seq_len",
  974. ]
  975. max_len_key = None
  976. for key in possible_keys:
  977. max_len = getattr(hf_config, key, None)
  978. if max_len is not None:
  979. max_len_key = key if max_len < derived_max_model_len \
  980. else max_len_key
  981. derived_max_model_len = min(derived_max_model_len, max_len)
  982. if derived_max_model_len == float("inf"):
  983. if max_model_len is not None:
  984. # If max_model_len is specified, we use it.
  985. return max_model_len
  986. default_max_len = 2048
  987. logger.warning(
  988. "The model's config.json does not contain any of the following "
  989. "keys to determine the original maximum length of the model: "
  990. f"{possible_keys}. Assuming the model's maximum length is "
  991. f"{default_max_len}.")
  992. derived_max_model_len = default_max_len
  993. rope_scaling = getattr(hf_config, "rope_scaling", None)
  994. if rope_scaling is not None and rope_scaling["type"] != "longrope" and \
  995. rope_scaling["type"] != "su":
  996. assert "factor" in rope_scaling
  997. scaling_factor = rope_scaling["factor"]
  998. if rope_scaling["type"] == "yarn":
  999. derived_max_model_len = rope_scaling[
  1000. "original_max_position_embeddings"]
  1001. derived_max_model_len *= scaling_factor
  1002. if max_model_len is None:
  1003. max_model_len = derived_max_model_len
  1004. elif max_model_len > derived_max_model_len:
  1005. # hope this works
  1006. scaling_factor = max_model_len / derived_max_model_len
  1007. hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
  1008. logger.warning(
  1009. f"User-specified max_model_len {max_model_len} is higher than "
  1010. f"the original {derived_max_model_len}. "
  1011. "Attempting to use RoPE scaling.")
  1012. derived_max_model_len = max_model_len
  1013. return int(max_model_len)
  1014. @dataclass
  1015. class DecodingConfig:
  1016. """Dataclass which contains the decoding strategy of the engine"""
  1017. # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
  1018. guided_decoding_backend: str = 'outlines'
  1019. def __post_init__(self):
  1020. valid_guided_backends = ['outlines', 'lm-format-enforcer']
  1021. backend = self.guided_decoding_backend
  1022. if backend not in valid_guided_backends:
  1023. raise ValueError(f"Invalid guided_decoding_backend '{backend},"
  1024. f"must be one of {valid_guided_backends}")
  1025. @dataclass(frozen=True)
  1026. class EngineConfig:
  1027. """Dataclass which contains all engine-related configuration. This
  1028. simplifies passing around the distinct configurations in the codebase.
  1029. """
  1030. model_config: ModelConfig
  1031. cache_config: CacheConfig
  1032. parallel_config: ParallelConfig
  1033. scheduler_config: SchedulerConfig
  1034. device_config: DeviceConfig
  1035. load_config: LoadConfig
  1036. lora_config: Optional[LoRAConfig]
  1037. vision_language_config: Optional[VisionLanguageConfig]
  1038. speculative_config: Optional[SpeculativeConfig]
  1039. decoding_config: Optional[DecodingConfig]
  1040. def __post_init__(self):
  1041. """Verify configs are valid & consistent with each other.
  1042. """
  1043. self.model_config.verify_with_parallel_config(self.parallel_config)
  1044. self.cache_config.verify_with_parallel_config(self.parallel_config)
  1045. if self.lora_config:
  1046. self.lora_config.verify_with_model_config(self.model_config)
  1047. self.lora_config.verify_with_scheduler_config(
  1048. self.scheduler_config)
  1049. def to_dict(self):
  1050. """Return the configs as a dictionary, for use in **kwargs.
  1051. """
  1052. return dict(
  1053. (field.name, getattr(self, field.name)) for field in fields(self))