config.py 76 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740
  1. import enum
  2. import json
  3. import os
  4. from dataclasses import dataclass, field, fields
  5. from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
  6. Type, Union)
  7. import torch
  8. from loguru import logger
  9. from transformers import PretrainedConfig
  10. from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
  11. cuda_device_count_stateless,
  12. get_cpu_memory, is_cpu, is_hip, is_neuron,
  13. is_openvino, is_xpu, print_warning_once)
  14. from aphrodite.distributed import get_current_tp_rank_partition_size
  15. from aphrodite.modeling.models import ModelRegistry
  16. from aphrodite.platforms import current_platform
  17. from aphrodite.quantization import QUANTIZATION_METHODS
  18. from aphrodite.transformers_utils.config import get_config, get_hf_text_config
  19. if TYPE_CHECKING:
  20. from ray.util.placement_group import PlacementGroup
  21. from aphrodite.executor.executor_base import ExecutorBase
  22. from aphrodite.modeling.model_loader.loader import BaseModelLoader
  23. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import ( # noqa: E501
  24. BaseTokenizerGroup)
  25. # If true, will load models from ModelScope instead of Hugging Face Hub.
  26. APHRODITE_USE_MODELSCOPE = os.environ.get("APHRODITE_USE_MODELSCOPE",
  27. "False").lower() == "true"
  28. _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
  29. _PP_SUPPORTED_MODELS = [
  30. "AquilaModel",
  31. "AquilaForCausalLM",
  32. "InternLMForCausalLM",
  33. "LlamaForCausalLM",
  34. "LLaMAForCausalLM",
  35. "MistralForCausalLM",
  36. "Phi3ForCausalLM",
  37. "MixtralForCausalLM",
  38. "NemotronForCausalLM",
  39. "Qwen2ForCausalLM",
  40. "Qwen2MoeForCausalLM",
  41. ]
  42. _OPTIMIZED_QUANTS = [
  43. "fp8",
  44. "marlin",
  45. "gptq_marlin_24",
  46. "gptq_marlin",
  47. "awq_marlin",
  48. "fbgemm_fp8",
  49. "compressed-tensors",
  50. "compressed_tensors",
  51. ]
  52. class ModelConfig:
  53. """Configuration for the model.
  54. Args:
  55. model: Name or path of the huggingface model to use.
  56. It is also used as the content for `model_name` tag in metrics
  57. output when `served_model_name` is not specified.
  58. tokenizer: Name or path of the huggingface tokenizer to use.
  59. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
  60. available, and "slow" will always use the slow tokenizer.
  61. trust_remote_code: Trust remote code (e.g., from HuggingFace) when
  62. downloading the model and tokenizer.
  63. dtype: Data type for model weights and activations. The "auto" option
  64. will use FP16 precision for FP32 and FP16 models, and BF16 precision
  65. for BF16 models.
  66. seed: Random seed for reproducibility.
  67. revision: The specific model version to use. It can be a branch name,
  68. a tag name, or a commit id. If unspecified, will use the default
  69. version.
  70. code_revision: The specific revision to use for the model code on
  71. Hugging Face Hub. It can be a branch name, a tag name, or a
  72. commit id. If unspecified, will use the default version.
  73. rope_scaling: Dictionary containing the scaling configuration for the
  74. RoPE embeddings. When using this flag, don't update
  75. `max_position_embeddings` to the expected new maximum.
  76. tokenizer_revision: The specific tokenizer version to use. It can be a
  77. branch name, a tag name, or a commit id. If unspecified, will use
  78. the default version.
  79. max_model_len: Maximum length of a sequence (including prompt and
  80. output). If None, will be derived from the model.
  81. quantization: Quantization method that was used to quantize the model
  82. weights. If None, we assume the model weights are not quantized.
  83. deepspeed_fp_bits: Number of bits to use for DeepSpeed FP quantization.
  84. Supported number of bits are: 4, 6, 8, 12.
  85. quantization_param_path: Path to JSON file containing scaling factors.
  86. Used to load KV cache scaling factors into the model when KV cache
  87. type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
  88. be used to load activation and weight scaling factors when the
  89. model dtype is FP8_E4M3 on ROCm.
  90. enforce_eager: Whether to enforce eager execution. If True, we will
  91. disable CUDA graph and always execute the model in eager mode.
  92. If False, we will use CUDA graph and eager execution in hybrid.
  93. If None, the user did not specify, so default to False -
  94. except for encoder/decoder models, which currently require
  95. eager mode.
  96. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
  97. When a sequence has context length larger than this, we fall back
  98. to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
  99. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
  100. When a sequence has context length larger than this, we fall back
  101. to eager mode
  102. disable_sliding_window: Whether to disable sliding window. If True,
  103. we will disable the sliding window functionality of the model.
  104. If the model does not support sliding window, this argument is
  105. ignored.
  106. skip_tokenizer_init: If true, skip initialization of tokenizer and
  107. detokenizer.
  108. served_model_name: The model name used in metrics tag `model_name`,
  109. matches the model name exposed via the APIs. If multiple model
  110. names provided, the first name will be used. If not specified,
  111. the model name will be the same as `model`.
  112. """
  113. def __init__(
  114. self,
  115. model: str,
  116. tokenizer: str,
  117. tokenizer_mode: str,
  118. trust_remote_code: bool,
  119. dtype: Union[str, torch.dtype],
  120. seed: int,
  121. revision: Optional[str] = None,
  122. code_revision: Optional[str] = None,
  123. rope_scaling: Optional[dict] = None,
  124. rope_theta: Optional[float] = None,
  125. tokenizer_revision: Optional[str] = None,
  126. max_model_len: Optional[int] = None,
  127. quantization: Optional[str] = None,
  128. deepspeed_fp_bits: Optional[int] = None,
  129. quantization_param_path: Optional[str] = None,
  130. enforce_eager: Optional[bool] = None,
  131. max_context_len_to_capture: Optional[int] = None,
  132. max_seq_len_to_capture: Optional[int] = None,
  133. max_logprobs: int = 5,
  134. disable_sliding_window: bool = False,
  135. skip_tokenizer_init: bool = False,
  136. served_model_name: Optional[Union[str, List[str]]] = None,
  137. multimodal_config: Optional["MultiModalConfig"] = None,
  138. ) -> None:
  139. self.model = model
  140. self.tokenizer = tokenizer
  141. self.tokenizer_mode = tokenizer_mode
  142. self.trust_remote_code = trust_remote_code
  143. self.seed = seed
  144. self.revision = revision
  145. self.code_revision = code_revision
  146. self.rope_scaling = rope_scaling
  147. self.rope_theta = rope_theta
  148. # The tokenizer version is consistent with the model version by default.
  149. if tokenizer_revision is None:
  150. self.tokenizer_revision = revision
  151. else:
  152. self.tokenizer_revision = tokenizer_revision
  153. self.quantization = quantization
  154. self.deepspeed_fp_bits = deepspeed_fp_bits
  155. self.quantization_param_path = quantization_param_path
  156. self.enforce_eager = enforce_eager
  157. self.max_context_len_to_capture = max_context_len_to_capture
  158. if self.max_context_len_to_capture is not None:
  159. raise ValueError("`max_context_len_to_capture` is deprecated. "
  160. "Use `max_seq_len_to_capture` instead.")
  161. self.max_seq_len_to_capture = (max_seq_len_to_capture
  162. or max_context_len_to_capture)
  163. self.max_logprobs = max_logprobs
  164. self.disable_sliding_window = disable_sliding_window
  165. self.skip_tokenizer_init = skip_tokenizer_init
  166. self.hf_config = get_config(self.model, trust_remote_code, revision,
  167. code_revision, rope_scaling, rope_theta)
  168. self.hf_text_config = get_hf_text_config(self.hf_config)
  169. self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
  170. # Choose a default enforce_eager value if the user did not specify
  171. # a value (enforce_eager is None)
  172. if getattr(self.hf_config, 'is_encoder_decoder', False):
  173. if self.enforce_eager is None:
  174. # *Only for encoder/decoder models* and
  175. # *only if enforce_eager is unset*, override
  176. # to enforce_eager=True
  177. #
  178. # Add a logger message since it is *somewhat* non-intuitive that
  179. # enforce_eager is True when the user has not specified its
  180. # value.
  181. logger.info("Forcing enforce_eager == True because "
  182. "enforce_eager setting was unspecified and "
  183. "CUDAGraph is not supported with encoder/ "
  184. "decoder models.")
  185. self.enforce_eager = True
  186. if not self.enforce_eager:
  187. # Eager mode explicitly disabled by user for an encoder/
  188. # decoder model; however CUDAGRAPH + encoder/decoder is
  189. # not currently supported
  190. raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
  191. elif self.enforce_eager is None:
  192. # *Only for decoder-only models*, enforce_eager
  193. # defaults to False if unset. This is intuitive
  194. # so no logging message needed.
  195. self.enforce_eager = False
  196. if (not self.disable_sliding_window
  197. and self.hf_text_config.model_type == "gemma2"
  198. and self.hf_text_config.sliding_window is not None):
  199. print_warning_once(
  200. "Gemma 2 uses sliding window attention for every odd layer, "
  201. "which is currently not supported by Aphrodite. Disabling "
  202. "sliding window and capping the max length to the sliding "
  203. f"window size ({self.hf_text_config.sliding_window}).")
  204. self.disable_sliding_window = True
  205. self.max_model_len = _get_and_verify_max_len(
  206. hf_config=self.hf_text_config,
  207. max_model_len=max_model_len,
  208. disable_sliding_window=self.disable_sliding_window,
  209. sliding_window_len=self.get_hf_config_sliding_window(),
  210. rope_scaling_arg=self.rope_scaling)
  211. self.served_model_name = get_served_model_name(model,
  212. served_model_name)
  213. self.multimodal_config = multimodal_config
  214. if not self.skip_tokenizer_init:
  215. self._verify_tokenizer_mode()
  216. self._verify_embedding_mode()
  217. self._verify_quantization()
  218. self._verify_cuda_graph()
  219. def _verify_tokenizer_mode(self) -> None:
  220. tokenizer_mode = self.tokenizer_mode.lower()
  221. if tokenizer_mode not in ["auto", "slow"]:
  222. raise ValueError(
  223. f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
  224. "either 'auto' or 'slow'.")
  225. self.tokenizer_mode = tokenizer_mode
  226. def _verify_embedding_mode(self) -> None:
  227. architectures = getattr(self.hf_config, "architectures", [])
  228. self.embedding_mode = any(
  229. ModelRegistry.is_embedding_model(arch) for arch in architectures)
  230. def _parse_quant_hf_config(self):
  231. quant_cfg = getattr(self.hf_config, "quantization_config", None)
  232. if quant_cfg is None:
  233. # compress-tensors uses a "compression_config" key
  234. quant_cfg = getattr(self.hf_config, "compression_config", None)
  235. return quant_cfg
  236. def _verify_quantization(self) -> None:
  237. supported_quantization = [*QUANTIZATION_METHODS]
  238. rocm_supported_quantization = ["gptq", "squeezellm"]
  239. tpu_supported_quantization = ["tpu_int8"]
  240. if self.quantization is not None:
  241. self.quantization = self.quantization.lower()
  242. # Parse quantization method from the HF model config, if available.
  243. quant_cfg = self._parse_quant_hf_config()
  244. if quant_cfg is not None:
  245. quant_method = quant_cfg.get("quant_method", "").lower()
  246. # Detect which checkpoint is it
  247. for _, method in QUANTIZATION_METHODS.items():
  248. quantization_override = method.override_quantization_method(
  249. quant_cfg, self.quantization)
  250. if quantization_override:
  251. quant_method = quantization_override
  252. self.quantization = quantization_override
  253. break
  254. # Verify quantization configurations.
  255. if self.quantization is None:
  256. self.quantization = quant_method
  257. elif self.quantization != quant_method:
  258. raise ValueError(
  259. "Quantization method specified in the model config "
  260. f"({quant_method}) does not match the quantization "
  261. f"method specified in the `quantization` argument "
  262. f"({self.quantization}).")
  263. if self.quantization == "deepspeedfp":
  264. gs = 32 if self.deepspeed_fp_bits == 4 else 128
  265. self.hf_config.quantization_config = {
  266. "bits": self.deepspeed_fp_bits,
  267. "group_size": int(os.environ.get("DEEPSPEED_GROUP_SIZE", gs)),
  268. "quant_method": "deepspeedfp"
  269. }
  270. if self.quantization is not None:
  271. if self.quantization not in supported_quantization:
  272. raise ValueError(
  273. f"Unknown quantization method: {self.quantization}. Must "
  274. f"be one of {supported_quantization}.")
  275. if is_hip(
  276. ) and self.quantization not in rocm_supported_quantization:
  277. raise ValueError(
  278. f"{self.quantization} quantization is currently not "
  279. "supported in ROCm.")
  280. if current_platform.is_tpu(
  281. ) and self.quantization not in tpu_supported_quantization:
  282. raise ValueError(
  283. f"{self.quantization} quantization is currently not "
  284. f"supported in TPU Backend.")
  285. if self.quantization not in _OPTIMIZED_QUANTS:
  286. logger.warning(
  287. f"{self.quantization} quantization is not fully "
  288. "optimized yet. The speed can be slower than "
  289. "non-quantized models.")
  290. if self.quantization == "deepspeedfp" and self.deepspeed_fp_bits \
  291. is None:
  292. raise ValueError(
  293. "deepspeed_fp_bits must be specified when using "
  294. "deepspeedfp quantization.")
  295. def _verify_cuda_graph(self) -> None:
  296. if self.max_seq_len_to_capture is None:
  297. self.max_seq_len_to_capture = self.max_model_len
  298. self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
  299. self.max_model_len)
  300. def verify_with_parallel_config(
  301. self,
  302. parallel_config: "ParallelConfig",
  303. ) -> None:
  304. total_num_attention_heads = getattr(self.hf_text_config,
  305. "num_attention_heads", 0)
  306. tensor_parallel_size = parallel_config.tensor_parallel_size
  307. if (total_num_attention_heads % tensor_parallel_size != 0
  308. and self.quantization is not None):
  309. raise ValueError(
  310. f"Total number of attention heads "
  311. f"({total_num_attention_heads})"
  312. " must be divisible by tensor parallel size "
  313. f"({tensor_parallel_size}) when quantization is used.")
  314. pipeline_parallel_size = parallel_config.pipeline_parallel_size
  315. architectures = getattr(self.hf_config, "architectures", [])
  316. if not all(arch in _PP_SUPPORTED_MODELS
  317. for arch in architectures) and pipeline_parallel_size > 1:
  318. raise NotImplementedError(
  319. "Pipeline parallelism is only supported for the following "
  320. f" architectures: {_PP_SUPPORTED_MODELS}.")
  321. if self.quantization == "bitsandbytes" and (
  322. parallel_config.tensor_parallel_size > 1
  323. or parallel_config.pipeline_parallel_size > 1):
  324. raise ValueError(
  325. "BitsAndBytes quantization with TP/PP is not supported yet.")
  326. if self.quantization == "bitsandbytes" and self.enforce_eager is False:
  327. raise ValueError(
  328. "BitsAndBytes with enforce_eager=False is not supported yet.")
  329. def is_attention_free(self) -> bool:
  330. """Returns True if the model has no attention, i.e. the model has no
  331. state that grows with the size of the context.
  332. """
  333. # Return true if the model is mamba.
  334. # This check should be augmented with more models in the future,
  335. # and made more robust if possible.
  336. if hasattr(self.hf_text_config,
  337. "model_type") and self.hf_text_config.model_type == 'mamba':
  338. return True
  339. return False
  340. def get_hf_config_sliding_window(self) -> Optional[int]:
  341. """Get the sliding window size, or None if disabled.
  342. """
  343. # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
  344. # addition to sliding window size. We check if that field is present
  345. # and if it's False, return None.
  346. if (hasattr(self.hf_text_config, "use_sliding_window")
  347. and not self.hf_text_config.use_sliding_window):
  348. return None
  349. return getattr(self.hf_text_config, "sliding_window", None)
  350. def get_sliding_window(self) -> Optional[int]:
  351. """Get the sliding window size, or None if disabled.
  352. """
  353. # If user disables sliding window, return None.
  354. if self.disable_sliding_window:
  355. return None
  356. # Otherwise get the value from the hf config.
  357. return self.get_hf_config_sliding_window()
  358. def get_vocab_size(self) -> int:
  359. return self.hf_text_config.vocab_size
  360. def get_hidden_size(self) -> int:
  361. return self.hf_text_config.hidden_size
  362. def get_head_size(self) -> int:
  363. # TODO remove hard code
  364. if hasattr(self.hf_text_config, "model_type"
  365. ) and self.hf_text_config.model_type == 'deepseek_v2':
  366. # FlashAttention supports only head_size 32, 64, 128, 256,
  367. # we need to pad head_size 192 to 256
  368. return 256
  369. if self.is_attention_free():
  370. return 0
  371. if hasattr(self.hf_text_config, "head_dim"):
  372. return self.hf_text_config.head_dim
  373. # FIXME: This may not be true for all models.
  374. return (self.hf_text_config.hidden_size //
  375. self.hf_text_config.num_attention_heads)
  376. def get_total_num_kv_heads(self) -> int:
  377. """Returns the total number of KV heads."""
  378. # For GPTBigCode & Falcon:
  379. # NOTE: for falcon, when new_decoder_architecture is True, the
  380. # multi_query flag is ignored and we use n_head_kv for the number of
  381. # KV heads.
  382. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
  383. new_decoder_arch_falcon = (
  384. self.hf_config.model_type in falcon_model_types
  385. and getattr(self.hf_config, "new_decoder_architecture", False))
  386. if not new_decoder_arch_falcon and getattr(self.hf_text_config,
  387. "multi_query", False):
  388. # Multi-query attention, only one KV head.
  389. # Currently, tensor parallelism is not supported in this case.
  390. return 1
  391. # For DBRX and MPT
  392. if self.hf_config.model_type == "mpt":
  393. if "kv_n_heads" in self.hf_config.attn_config:
  394. return self.hf_config.attn_config["kv_n_heads"]
  395. return self.hf_config.num_attention_heads
  396. if self.hf_config.model_type == "dbrx":
  397. return getattr(self.hf_config.attn_config, "kv_n_heads",
  398. self.hf_config.num_attention_heads)
  399. if self.is_attention_free():
  400. return 0
  401. attributes = [
  402. # For Falcon:
  403. "n_head_kv",
  404. "num_kv_heads",
  405. # For LLaMA-2:
  406. "num_key_value_heads",
  407. # For ChatGLM:
  408. "multi_query_group_num",
  409. ]
  410. for attr in attributes:
  411. num_kv_heads = getattr(self.hf_text_config, attr, None)
  412. if num_kv_heads is not None:
  413. return num_kv_heads
  414. # For non-grouped-query attention models, the number of KV heads is
  415. # equal to the number of attention heads.
  416. return self.hf_text_config.num_attention_heads
  417. def get_num_kv_heads(self,
  418. parallel_config: "ParallelConfig",
  419. tp_rank: int = 0) -> int:
  420. """Returns the number of KV heads per GPU."""
  421. total_num_kv_heads = self.get_total_num_kv_heads()
  422. # If tensor parallelism is used, we divide the number of KV heads by
  423. # the tensor parallel size. We will replicate the KV heads in the
  424. # case where the number of KV heads is smaller than the tensor
  425. # parallel size so each GPU has at least one KV head.
  426. result = get_current_tp_rank_partition_size(
  427. total_num_kv_heads, tp_rank, parallel_config.tensor_parallel_size)
  428. return max(1, result)
  429. def get_num_attention_heads(self,
  430. parallel_config: "ParallelConfig",
  431. tp_rank: int = 0) -> int:
  432. if getattr(self.hf_text_config, "num_attention_heads", None) is None:
  433. return 0
  434. num_total_kv_heads = self.get_total_num_kv_heads()
  435. num_kv_heads = self.get_num_kv_heads(parallel_config, tp_rank)
  436. num_total_attention_heads = self.hf_text_config.num_attention_heads
  437. num_heads_per_kv_head = num_total_attention_heads // num_total_kv_heads
  438. # For GQA attention we make sure the whole attention head group is
  439. # together on the same GPU.
  440. return num_kv_heads * num_heads_per_kv_head
  441. def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  442. from aphrodite.distributed.utils import get_pp_indices
  443. total_num_hidden_layers = getattr(self.hf_text_config,
  444. "num_hidden_layers", 0)
  445. pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
  446. pp_size = parallel_config.pipeline_parallel_size
  447. start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
  448. return end - start
  449. def contains_seqlen_agnostic_layers(
  450. self, parallel_config: "ParallelConfig") -> bool:
  451. """True for Mamba/SSM models (Jamba)"""
  452. return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
  453. def get_layers_block_type(self,
  454. parallel_config: "ParallelConfig") -> List[str]:
  455. num_layers = self.get_num_layers(parallel_config)
  456. if self.is_attention_free():
  457. assert (self.hf_config.model_type == "mamba")
  458. return ["mamba"] * num_layers
  459. # Transformers supports layers_block_type @property
  460. return getattr(self.hf_config, "layers_block_type",
  461. ["attention"] * num_layers)
  462. def get_num_attention_layers(self,
  463. parallel_config: "ParallelConfig") -> int:
  464. return len([
  465. t for t in self.get_layers_block_type(parallel_config)
  466. if t == "attention"
  467. ])
  468. def _get_num_seqlen_agnostic_layers(
  469. self, parallel_config: "ParallelConfig") -> int:
  470. return len([
  471. t for t in self.get_layers_block_type(parallel_config)
  472. if t != "attention"
  473. ])
  474. @property
  475. def is_encoder_decoder_model(self) -> bool:
  476. """Extract the HF encoder/decoder model flag."""
  477. return getattr(self.hf_config, "is_encoder_decoder", False)
  478. @property
  479. def is_embedding_model(self) -> bool:
  480. """Extract the embedding model flag."""
  481. return self.embedding_mode
  482. class CacheConfig:
  483. """Configuration for the KV cache.
  484. Args:
  485. block_size: Size of a cache block in number of tokens.
  486. gpu_memory_utilization: Fraction of GPU memory to use for the
  487. Aphrodite execution.
  488. swap_space: Size of the CPU swap space per GPU (in GiB).
  489. cache_dtype: Data type for kv cache storage.
  490. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
  491. profiled num_gpu_blocks if specified. Does nothing if None.
  492. """
  493. def __init__(
  494. self,
  495. block_size: int,
  496. gpu_memory_utilization: float,
  497. swap_space: float,
  498. cache_dtype: str,
  499. is_attention_free: bool,
  500. num_gpu_blocks_override: Optional[int] = None,
  501. sliding_window: Optional[int] = None,
  502. enable_prefix_caching: bool = False,
  503. cpu_offload_gb: float = 0.0,
  504. ) -> None:
  505. self.block_size = block_size
  506. self.gpu_memory_utilization = gpu_memory_utilization
  507. self.swap_space_bytes = swap_space * GiB_bytes
  508. self.num_gpu_blocks_override = num_gpu_blocks_override
  509. self.cache_dtype = cache_dtype
  510. self.is_attention_free = is_attention_free
  511. self.sliding_window = sliding_window
  512. self.enable_prefix_caching = enable_prefix_caching
  513. self.cpu_offload_gb = cpu_offload_gb
  514. self._verify_args()
  515. self._verify_cache_dtype()
  516. self._verify_prefix_caching()
  517. # Will be set after profiling.
  518. self.num_gpu_blocks = None
  519. self.num_cpu_blocks = None
  520. def metrics_info(self):
  521. # convert cache_config to dict(key: str, value: str) for prometheus
  522. # metrics info
  523. return {key: str(value) for key, value in self.__dict__.items()}
  524. def _verify_args(self) -> None:
  525. if self.gpu_memory_utilization > 1.0:
  526. raise ValueError(
  527. "GPU memory utilization must be less than 1.0. Got "
  528. f"{self.gpu_memory_utilization}.")
  529. def _verify_cache_dtype(self) -> None:
  530. if self.cache_dtype == "auto":
  531. pass
  532. elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
  533. logger.info(
  534. "Using fp8 data type to store kv cache. It reduces the GPU "
  535. "memory footprint and boosts the performance. "
  536. "Meanwhile, it may cause accuracy drop without a proper "
  537. "scaling factor")
  538. else:
  539. raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
  540. def _verify_prefix_caching(self) -> None:
  541. if not self.enable_prefix_caching:
  542. return
  543. if self.sliding_window is not None:
  544. raise NotImplementedError(
  545. "Prefix caching is not supported with sliding window. "
  546. "Run with --disable-sliding-window to use prefix caching.")
  547. if self.cache_dtype == "fp8":
  548. capability = current_platform.get_device_capability()
  549. capability = capability[0] * 10 + capability[1]
  550. if capability < 89:
  551. raise NotImplementedError(
  552. "FP8 KV cache with prefix caching is only supported on "
  553. "GPUs with compute capability 8.9 or higher (e.g., "
  554. "4090, H100). Your GPU has compute capability "
  555. f"{capability}")
  556. def verify_with_parallel_config(
  557. self,
  558. parallel_config: "ParallelConfig",
  559. ) -> None:
  560. total_cpu_memory = get_cpu_memory()
  561. # FIXME: Here, it is assumed that the GPUs in a tensor parallel
  562. # group are in the same node. However, the GPUs may span multiple nodes.
  563. num_gpus_per_node = parallel_config.tensor_parallel_size
  564. cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
  565. msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
  566. f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
  567. "is allocated for the swap space.")
  568. if cpu_memory_usage > 0.7 * total_cpu_memory:
  569. raise ValueError("Too large swap space. " + msg)
  570. elif cpu_memory_usage > 0.4 * total_cpu_memory:
  571. logger.warning("Possibly too large swap space. " + msg)
  572. @dataclass
  573. class TokenizerPoolConfig:
  574. """Configuration for the tokenizer pool.
  575. Args:
  576. pool_size: Number of tokenizer workers in the pool.
  577. pool_type: Type of the pool.
  578. extra_config: Additional config for the pool.
  579. The way the config will be used depends on the
  580. pool type.
  581. """
  582. pool_size: int
  583. pool_type: Union[str, Type["BaseTokenizerGroup"]]
  584. extra_config: dict
  585. def __post_init__(self):
  586. if self.pool_type not in ("ray", ) and not isinstance(
  587. self.pool_type, type):
  588. raise ValueError(f"Unknown pool type: {self.pool_type}")
  589. if not isinstance(self.extra_config, dict):
  590. raise ValueError("extra_config must be a dictionary.")
  591. @classmethod
  592. def create_config(
  593. cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
  594. tokenizer_pool_extra_config: Optional[Union[str, dict]]
  595. ) -> Optional["TokenizerPoolConfig"]:
  596. """Create a TokenizerPoolConfig from the given parameters.
  597. If tokenizer_pool_size is 0, return None.
  598. Args:
  599. tokenizer_pool_size: Number of tokenizer workers in the pool.
  600. tokenizer_pool_type: Type of the pool.
  601. tokenizer_pool_extra_config: Additional config for the pool.
  602. The way the config will be used depends on the
  603. pool type. This can be a JSON string (will be parsed).
  604. """
  605. if tokenizer_pool_size:
  606. if isinstance(tokenizer_pool_extra_config, str):
  607. tokenizer_pool_extra_config_parsed = json.loads(
  608. tokenizer_pool_extra_config)
  609. else:
  610. tokenizer_pool_extra_config_parsed = (
  611. tokenizer_pool_extra_config or {})
  612. tokenizer_pool_config = cls(tokenizer_pool_size,
  613. tokenizer_pool_type,
  614. tokenizer_pool_extra_config_parsed)
  615. else:
  616. tokenizer_pool_config = None
  617. return tokenizer_pool_config
  618. class LoadFormat(str, enum.Enum):
  619. AUTO = "auto"
  620. PT = "pt"
  621. SAFETENSORS = "safetensors"
  622. NPCACHE = "npcache"
  623. DUMMY = "dummy"
  624. TENSORIZER = "tensorizer"
  625. SHARDED_STATE = "sharded_state"
  626. GGUF = "gguf"
  627. BITSANDBYTES = "bitsandbytes"
  628. @dataclass
  629. class LoadConfig:
  630. """
  631. download_dir: Directory to download and load the weights, default to the
  632. default cache directory of huggingface.
  633. load_format: The format of the model weights to load:
  634. "auto" will try to load the weights in the safetensors format and
  635. fall back to the pytorch bin format if safetensors format is
  636. not available.
  637. "pt" will load the weights in the pytorch bin format.
  638. "safetensors" will load the weights in the safetensors format.
  639. "npcache" will load the weights in pytorch format and store
  640. a numpy cache to speed up the loading.
  641. "dummy" will initialize the weights with random values, which is
  642. mainly for profiling.
  643. "tensorizer" will use CoreWeave's tensorizer library for
  644. fast weight loading.
  645. ignore_patterns: The list of patterns to ignore when loading the model.
  646. Default to "original/**/*" to avoid repeated loading of llama's
  647. checkpoints.
  648. """
  649. load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
  650. download_dir: Optional[str] = None
  651. model_loader_extra_config: Optional[Union[str, dict]] = field(
  652. default_factory=dict)
  653. ignore_patterns: Optional[Union[List[str], str]] = None
  654. def __post_init__(self):
  655. model_loader_extra_config = self.model_loader_extra_config or {}
  656. if isinstance(model_loader_extra_config, str):
  657. self.model_loader_extra_config = json.loads(
  658. model_loader_extra_config)
  659. self._verify_load_format()
  660. if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
  661. logger.info(
  662. "Ignoring the following patterns when downloading weights: "
  663. f"{self.ignore_patterns}")
  664. else:
  665. self.ignore_patterns = ["original/**/*", "consolidated*"]
  666. def _verify_load_format(self) -> None:
  667. if not isinstance(self.load_format, str):
  668. return
  669. load_format = self.load_format.lower()
  670. self.load_format = LoadFormat(load_format)
  671. rocm_not_supported_load_format: List[str] = []
  672. if is_hip() and load_format in rocm_not_supported_load_format:
  673. rocm_supported_load_format = [
  674. f for f in LoadFormat.__members__
  675. if (f not in rocm_not_supported_load_format)
  676. ]
  677. raise ValueError(
  678. f"load format '{load_format}' is not supported in ROCm. "
  679. f"Supported load formats are "
  680. f"{rocm_supported_load_format}")
  681. class ParallelConfig:
  682. """Configuration for the distributed execution.
  683. Args:
  684. pipeline_parallel_size: Number of pipeline parallel groups.
  685. tensor_parallel_size: Number of tensor parallel groups.
  686. worker_use_ray: Deprecated, use distributed_executor_backend instead.
  687. max_parallel_loading_workers: Maximum number of multiple batches
  688. when load model sequentially. To avoid RAM OOM when using tensor
  689. parallel and large models.
  690. disable_custom_all_reduce: Disable the custom all-reduce kernel and
  691. fall back to NCCL.
  692. tokenizer_pool_config: Config for the tokenizer pool.
  693. If None, will use synchronous tokenization.
  694. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
  695. https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
  696. placement_group: ray distributed model workers placement group.
  697. distributed_executor_backend: Backend to use for distributed model
  698. workers, either "ray" or "mp" (multiprocessing). If either
  699. pipeline_parallel_size or tensor_parallel_size is greater than 1,
  700. will default to "ray" if Ray is installed or "mp" otherwise.
  701. """
  702. def __init__(
  703. self,
  704. pipeline_parallel_size: int,
  705. tensor_parallel_size: int,
  706. worker_use_ray: Optional[bool] = None,
  707. max_parallel_loading_workers: Optional[int] = None,
  708. disable_custom_all_reduce: bool = False,
  709. tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
  710. ray_workers_use_nsight: bool = False,
  711. placement_group: Optional["PlacementGroup"] = None,
  712. distributed_executor_backend: Optional[Union[
  713. str, Type["ExecutorBase"]]] = None,
  714. ) -> None:
  715. self.pipeline_parallel_size = pipeline_parallel_size
  716. self.tensor_parallel_size = tensor_parallel_size
  717. self.distributed_executor_backend = distributed_executor_backend
  718. self.max_parallel_loading_workers = max_parallel_loading_workers
  719. self.disable_custom_all_reduce = disable_custom_all_reduce
  720. self.tokenizer_pool_config = tokenizer_pool_config
  721. self.ray_workers_use_nsight = ray_workers_use_nsight
  722. self.placement_group = placement_group
  723. self.world_size = pipeline_parallel_size * self.tensor_parallel_size
  724. if worker_use_ray:
  725. if self.distributed_executor_backend is None:
  726. self.distributed_executor_backend = "ray"
  727. elif not self.use_ray:
  728. raise ValueError(f"worker-use-ray can't be used with "
  729. f"distributed executor backend "
  730. f"'{self.distributed_executor_backend}'.")
  731. if self.distributed_executor_backend is None and self.world_size > 1:
  732. # We use multiprocessing by default if world_size fits on the
  733. # current node and we aren't in a ray placement group.
  734. from aphrodite.executor import ray_utils
  735. backend = "mp"
  736. ray_found = ray_utils.ray_is_available()
  737. if cuda_device_count_stateless() < self.world_size:
  738. if not ray_found:
  739. raise ValueError("Unable to load Ray which is "
  740. "required for multi-node inference, "
  741. "please install Ray with `pip install "
  742. "ray`.") from ray_utils.ray_import_err
  743. backend = "ray"
  744. elif ray_found:
  745. if self.placement_group:
  746. backend = "ray"
  747. else:
  748. from ray import is_initialized as ray_is_initialized
  749. if ray_is_initialized():
  750. from ray.util import get_current_placement_group
  751. if get_current_placement_group():
  752. backend = "ray"
  753. self.distributed_executor_backend = backend
  754. logger.info(
  755. f"Defaulting to use {backend} for distributed inference.")
  756. self._verify_args()
  757. self.rank = 0
  758. @property
  759. def use_ray(self) -> bool:
  760. return self.distributed_executor_backend == "ray" or (
  761. isinstance(self.distributed_executor_backend, type)
  762. and self.distributed_executor_backend.uses_ray)
  763. def _verify_args(self) -> None:
  764. # Lazy import to avoid circular import
  765. from aphrodite.executor.executor_base import ExecutorBase
  766. if self.distributed_executor_backend not in (
  767. "ray", "mp", None) and not (isinstance(
  768. self.distributed_executor_backend, type) and issubclass(
  769. self.distributed_executor_backend, ExecutorBase)):
  770. raise ValueError(
  771. "Unrecognized distributed executor backend "
  772. f"{self.distributed_executor_backend}. Supported "
  773. "values are 'ray', 'mp' or custom ExecutorBase subclass.")
  774. if self.use_ray:
  775. from aphrodite.executor import ray_utils
  776. ray_utils.assert_ray_available()
  777. if is_hip():
  778. self.disable_custom_all_reduce = True
  779. logger.info(
  780. "Disabled the custom all-reduce kernel because it is not "
  781. "supported on AMD GPUs.")
  782. if self.ray_workers_use_nsight and not self.use_ray:
  783. raise ValueError("Unable to use nsight profiling unless workers "
  784. "run with Ray.")
  785. class SchedulerConfig:
  786. """Scheduler configuration.
  787. Args:
  788. max_num_batched_tokens: Maximum number of tokens to be processed in
  789. a single iteration.
  790. max_num_seqs: Maximum number of sequences to be processed in a single
  791. iteration.
  792. max_model_len: Maximum length of a sequence (including prompt
  793. and generated text).
  794. is_attention_free: True if the running model does not have state that
  795. grows as the context size increases.
  796. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
  797. num_lookahead_slots: The number of slots to allocate per sequence per
  798. step, beyond the known token ids. This is used in speculative
  799. decoding to store KV activations of tokens which may or may not be
  800. accepted.
  801. delay_factor: Apply a delay (of delay factor multiplied by previous
  802. prompt latency) before scheduling next prompt.
  803. enable_chunked_prefill: If True, prefill requests can be chunked based
  804. on the remaining max_num_batched_tokens.
  805. embedding_mode: Whether the running model is for embedding.
  806. preemption_mode: Whether to perform preemption by swapping or
  807. recomputation. If not specified, we determine the mode as follows:
  808. We use recomputation by default since it incurs lower overhead than
  809. swapping. However, when the sequence group has multiple sequences
  810. (e.g., beam search), recomputation is not currently supported. In
  811. such a case, we use swapping instead.
  812. """
  813. def __init__(self,
  814. max_num_batched_tokens: Optional[int],
  815. max_num_seqs: int,
  816. max_model_len: int,
  817. is_attention_free: bool,
  818. use_v2_block_manager: bool = False,
  819. num_lookahead_slots: int = 0,
  820. delay_factor: float = 0.0,
  821. enable_chunked_prefill: bool = False,
  822. embedding_mode: Optional[bool] = False,
  823. preemption_mode: Optional[str] = None) -> None:
  824. if max_num_batched_tokens is not None:
  825. self.max_num_batched_tokens = max_num_batched_tokens
  826. else:
  827. if enable_chunked_prefill:
  828. # For chunked prefill, choose the well-tuned batch size.
  829. self.max_num_batched_tokens = 768
  830. elif embedding_mode:
  831. # For embedding, choose specific value for higher throughput
  832. self.max_num_batched_tokens = max(
  833. max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
  834. else:
  835. # If max_model_len is too short, use 2048 as the default value
  836. # for higher throughput.
  837. self.max_num_batched_tokens = max(max_model_len, 2048)
  838. if enable_chunked_prefill:
  839. logger.info(
  840. "Chunked prefill is enabled with "
  841. f"max_num_batched_tokens={self.max_num_batched_tokens}.")
  842. self.max_num_seqs = max_num_seqs
  843. self.max_model_len = max_model_len
  844. self.is_attention_free = is_attention_free
  845. self.use_v2_block_manager = use_v2_block_manager
  846. self.num_lookahead_slots = num_lookahead_slots
  847. self.delay_factor = delay_factor
  848. self.chunked_prefill_enabled = enable_chunked_prefill
  849. self.embedding_mode = embedding_mode
  850. self.preemption_mode = preemption_mode
  851. self._verify_args()
  852. def _verify_args(self) -> None:
  853. if (self.max_num_batched_tokens < self.max_model_len
  854. and not self.chunked_prefill_enabled):
  855. raise ValueError(
  856. f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
  857. f"smaller than max_model_len ({self.max_model_len}). "
  858. "This effectively limits the maximum sequence length to "
  859. "max_num_batched_tokens and makes Aphrodite reject longer "
  860. "sequences. Please increase max_num_batched_tokens or "
  861. "decrease max_model_len.")
  862. if self.max_num_batched_tokens < self.max_num_seqs:
  863. raise ValueError(
  864. f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
  865. "be greater than or equal to max_num_seqs "
  866. f"({self.max_num_seqs}).")
  867. if self.num_lookahead_slots < 0:
  868. raise ValueError(
  869. "num_lookahead_slots "
  870. f"({self.num_lookahead_slots}) must be greater than or "
  871. "equal to 0.")
  872. class DeviceConfig:
  873. def __init__(self, device: str = "auto") -> None:
  874. if device == "auto":
  875. # Automated device type detection
  876. if is_neuron():
  877. self.device_type = "neuron"
  878. elif is_openvino():
  879. self.device_type = "openvino"
  880. elif current_platform.is_tpu():
  881. self.device_type = "tpu"
  882. elif is_cpu():
  883. self.device_type = "cpu"
  884. elif is_xpu():
  885. self.device_type = "xpu"
  886. else:
  887. # We don't call torch.cuda.is_available() here to
  888. # avoid initializing CUDA before workers are forked
  889. self.device_type = "cuda"
  890. else:
  891. # Device type is assigned explicitly
  892. self.device_type = device
  893. # Some device types require processing inputs on CPU
  894. if self.device_type in ["neuron", "openvino"]:
  895. self.device = torch.device("cpu")
  896. elif self.device_type in ["tpu"]:
  897. self.device = None
  898. else:
  899. # Set device with device type
  900. self.device = torch.device(self.device_type)
  901. class SpeculativeConfig:
  902. """Configuration for speculative decoding.
  903. The configuration is currently specialized to draft-model speculative
  904. decoding with top-1 proposals.
  905. """
  906. @staticmethod
  907. def maybe_create_spec_config(
  908. target_model_config: ModelConfig,
  909. target_parallel_config: ParallelConfig,
  910. target_dtype: str,
  911. speculative_model: Optional[str],
  912. speculative_draft_tensor_parallel_size: Optional[int],
  913. num_speculative_tokens: Optional[int],
  914. speculative_max_model_len: Optional[int],
  915. enable_chunked_prefill: bool,
  916. use_v2_block_manager: bool,
  917. disable_log_stats: bool,
  918. speculative_disable_by_batch_size: Optional[int],
  919. ngram_prompt_lookup_max: Optional[int],
  920. ngram_prompt_lookup_min: Optional[int],
  921. draft_token_acceptance_method: str,
  922. typical_acceptance_sampler_posterior_threshold: Optional[float],
  923. typical_acceptance_sampler_posterior_alpha: Optional[float],
  924. disable_logprobs: Optional[bool],
  925. ) -> Optional["SpeculativeConfig"]:
  926. """Create a SpeculativeConfig if possible, else return None.
  927. This function attempts to create a SpeculativeConfig object based on the
  928. provided parameters. If the necessary conditions are met, it returns an
  929. instance of SpeculativeConfig. Otherwise, it returns None.
  930. Args:
  931. target_model_config (ModelConfig): The configuration of the target
  932. model.
  933. target_parallel_config (ParallelConfig): The parallel configuration
  934. for the target model.
  935. target_dtype (str): The data type used for the target model.
  936. speculative_model (Optional[str]): The name of the speculative
  937. model, if provided.
  938. num_speculative_tokens (Optional[int]): The number of speculative
  939. tokens, if provided. Will default to the number in the draft
  940. model config if present, otherwise is required.
  941. speculative_draft_tensor_parallel_size (Optional[int]): The degree
  942. of the tensor parallelism for the draft model.
  943. speculative_max_model_len (Optional[int]): The maximum model len of
  944. the speculative model. Used when testing the ability to skip
  945. speculation for some sequences.
  946. enable_chunked_prefill (bool): Whether Aphrodite is configured to
  947. use chunked prefill or not. Used for raising an error since its
  948. not yet compatible with spec decode.
  949. use_v2_block_manager (bool): Whether Aphrodite is configured to
  950. use the v2 block manager or not. Used for raising an error
  951. since the v2 block manager is required with spec decode.
  952. speculative_disable_by_batch_size (Optional[int]): Disable
  953. speculative decoding for new incoming requests when the number
  954. of enqueue requests is larger than this value, if provided.
  955. ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
  956. window, if provided.
  957. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
  958. window, if provided.
  959. draft_token_acceptance_method (str): The method to use for
  960. accepting draft tokens. This can take two possible
  961. values 'rejection_sampler' and 'typical_acceptance_sampler'
  962. for RejectionSampler and TypicalAcceptanceSampler
  963. respectively.
  964. typical_acceptance_sampler_posterior_threshold (Optional[float]):
  965. A threshold value that sets a lower bound on the posterior
  966. probability of a token in the target model for it to be
  967. accepted. This threshold is used only when we use the
  968. TypicalAcceptanceSampler for token acceptance.
  969. typical_acceptance_sampler_posterior_alpha (Optional[float]):
  970. A scaling factor for the entropy-based threshold in the
  971. TypicalAcceptanceSampler.
  972. disable_logprobs (Optional[bool]): If set to True, token log
  973. probabilities are not returned during speculative decoding.
  974. If set to False, token log probabilities are returned
  975. according to the log probability settings in SamplingParams.
  976. If not specified, it defaults to True.
  977. Returns:
  978. Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
  979. the necessary conditions are met, else None.
  980. """
  981. if speculative_model is None:
  982. if num_speculative_tokens is not None:
  983. raise ValueError("num_speculative_tokens was provided without "
  984. "speculative_model.")
  985. return None
  986. if (speculative_disable_by_batch_size is not None
  987. and speculative_disable_by_batch_size < 2):
  988. raise ValueError("Expected the batch size threshold of disabling "
  989. "speculative decoding is > 1, but got "
  990. f"{speculative_disable_by_batch_size=}")
  991. if enable_chunked_prefill:
  992. raise ValueError(
  993. "Speculative decoding and chunked prefill are "
  994. f"currently mutually exclusive ({enable_chunked_prefill=}).")
  995. if not use_v2_block_manager:
  996. raise ValueError(
  997. "Speculative decoding requires usage of the V2 "
  998. "block manager. Enable it with --use-v2-block-manager.")
  999. # TODO: The user should be able to specify revision/quantization/max
  1000. # model len for the draft model. It is not currently supported.
  1001. draft_revision = None
  1002. draft_code_revision = None
  1003. draft_quantization = None
  1004. if speculative_model == "[ngram]":
  1005. if ngram_prompt_lookup_min is None:
  1006. ngram_prompt_lookup_min = 1
  1007. if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
  1008. raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
  1009. if ngram_prompt_lookup_min < 1:
  1010. raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
  1011. if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
  1012. raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
  1013. f"larger than {ngram_prompt_lookup_max=}")
  1014. # TODO: current we still need extract vocab_size from target model
  1015. # config, in future, we may try refactoring it out, and set
  1016. # draft related config as None here.
  1017. draft_model_config = target_model_config
  1018. draft_parallel_config = target_parallel_config
  1019. else:
  1020. ngram_prompt_lookup_max = 0
  1021. ngram_prompt_lookup_min = 0
  1022. draft_model_config = ModelConfig(
  1023. model=speculative_model,
  1024. tokenizer=target_model_config.tokenizer,
  1025. tokenizer_mode=target_model_config.tokenizer_mode,
  1026. trust_remote_code=target_model_config.trust_remote_code,
  1027. dtype=target_model_config.dtype,
  1028. seed=target_model_config.seed,
  1029. revision=draft_revision,
  1030. code_revision=draft_code_revision,
  1031. tokenizer_revision=target_model_config.tokenizer_revision,
  1032. max_model_len=None,
  1033. quantization=draft_quantization,
  1034. enforce_eager=target_model_config.enforce_eager,
  1035. max_seq_len_to_capture=target_model_config.
  1036. max_seq_len_to_capture,
  1037. max_logprobs=target_model_config.max_logprobs,
  1038. )
  1039. draft_hf_config = draft_model_config.hf_config
  1040. if (num_speculative_tokens is not None
  1041. and hasattr(draft_hf_config, "num_lookahead_tokens")):
  1042. draft_hf_config.num_lookahead_tokens = num_speculative_tokens
  1043. n_predict = getattr(draft_hf_config, "n_predict", None)
  1044. if n_predict is not None:
  1045. if num_speculative_tokens is None:
  1046. # Default to max value defined in draft model config.
  1047. num_speculative_tokens = n_predict
  1048. elif num_speculative_tokens > n_predict:
  1049. # Verify provided value doesn't exceed the maximum
  1050. # supported by the draft model.
  1051. raise ValueError(
  1052. "This speculative model supports a maximum of "
  1053. f"num_speculative_tokens={n_predict}, but "
  1054. f"{num_speculative_tokens=} was provided.")
  1055. draft_model_config.max_model_len = (
  1056. SpeculativeConfig._maybe_override_draft_max_model_len(
  1057. speculative_max_model_len,
  1058. draft_model_config.max_model_len,
  1059. target_model_config.max_model_len,
  1060. ))
  1061. draft_parallel_config = (
  1062. SpeculativeConfig.create_draft_parallel_config(
  1063. target_parallel_config,
  1064. speculative_draft_tensor_parallel_size))
  1065. if num_speculative_tokens is None:
  1066. raise ValueError(
  1067. "num_speculative_tokens must be provided with "
  1068. "speculative_model unless the draft model config contains an "
  1069. "n_predict parameter.")
  1070. if typical_acceptance_sampler_posterior_threshold is None:
  1071. typical_acceptance_sampler_posterior_threshold = 0.09
  1072. if typical_acceptance_sampler_posterior_alpha is None:
  1073. typical_acceptance_sampler_posterior_alpha = 0.3
  1074. if disable_logprobs is None:
  1075. disable_logprobs = True
  1076. return SpeculativeConfig(
  1077. draft_model_config,
  1078. draft_parallel_config,
  1079. num_speculative_tokens,
  1080. speculative_disable_by_batch_size,
  1081. ngram_prompt_lookup_max,
  1082. ngram_prompt_lookup_min,
  1083. draft_token_acceptance_method=draft_token_acceptance_method,
  1084. typical_acceptance_sampler_posterior_threshold=\
  1085. typical_acceptance_sampler_posterior_threshold,
  1086. typical_acceptance_sampler_posterior_alpha=\
  1087. typical_acceptance_sampler_posterior_alpha,
  1088. disable_logprobs=disable_logprobs,
  1089. disable_log_stats=disable_log_stats,
  1090. )
  1091. @staticmethod
  1092. def _maybe_override_draft_max_model_len(
  1093. speculative_max_model_len: Optional[int],
  1094. draft_max_model_len: int,
  1095. target_max_model_len: int,
  1096. ) -> int:
  1097. """Determine the max sequence len for the draft model. This is usually
  1098. the draft_max_model_len, but may be the target_max_model_len if it is
  1099. less than the draft_max_model_len, or may be speculative_max_model_len
  1100. if it is specified.
  1101. This is necessary so that sequences do not exceed the capacity of the
  1102. draft model or the target model.
  1103. speculative_max_model_len is mainly used for testing that sequences can
  1104. skip speculation.
  1105. """
  1106. if speculative_max_model_len is not None:
  1107. if speculative_max_model_len > draft_max_model_len:
  1108. raise ValueError(f"{speculative_max_model_len=} cannot be "
  1109. f"larger than {draft_max_model_len=}")
  1110. if speculative_max_model_len > target_max_model_len:
  1111. raise ValueError(f"{speculative_max_model_len=} cannot be "
  1112. f"larger than {target_max_model_len=}")
  1113. return speculative_max_model_len
  1114. return min(
  1115. draft_max_model_len,
  1116. target_max_model_len,
  1117. )
  1118. @staticmethod
  1119. def create_draft_parallel_config(
  1120. target_parallel_config: ParallelConfig,
  1121. speculative_draft_tensor_parallel_size: Optional[int]
  1122. ) -> ParallelConfig:
  1123. """Create a parallel config for use by the draft worker.
  1124. This is mostly a copy of the target parallel config, except the tp_size.
  1125. """
  1126. if speculative_draft_tensor_parallel_size is None:
  1127. speculative_draft_tensor_parallel_size = \
  1128. target_parallel_config.tensor_parallel_size
  1129. elif speculative_draft_tensor_parallel_size != 1:
  1130. # TODO: allow tp values larger than 1
  1131. raise ValueError(
  1132. f"{speculative_draft_tensor_parallel_size=} cannot be"
  1133. f"other value than 1")
  1134. draft_parallel_config = ParallelConfig(
  1135. pipeline_parallel_size=target_parallel_config.
  1136. pipeline_parallel_size,
  1137. tensor_parallel_size=speculative_draft_tensor_parallel_size,
  1138. distributed_executor_backend=target_parallel_config.
  1139. distributed_executor_backend,
  1140. max_parallel_loading_workers=target_parallel_config.
  1141. max_parallel_loading_workers,
  1142. disable_custom_all_reduce=target_parallel_config.
  1143. disable_custom_all_reduce,
  1144. tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
  1145. ray_workers_use_nsight=target_parallel_config.
  1146. ray_workers_use_nsight,
  1147. placement_group=target_parallel_config.placement_group,
  1148. )
  1149. return draft_parallel_config
  1150. def __init__(
  1151. self,
  1152. draft_model_config: ModelConfig,
  1153. draft_parallel_config: ParallelConfig,
  1154. num_speculative_tokens: int,
  1155. speculative_disable_by_batch_size: Optional[int],
  1156. ngram_prompt_lookup_max: Optional[int],
  1157. ngram_prompt_lookup_min: Optional[int],
  1158. draft_token_acceptance_method: str,
  1159. typical_acceptance_sampler_posterior_threshold: float,
  1160. typical_acceptance_sampler_posterior_alpha: float,
  1161. disable_logprobs: bool,
  1162. disable_log_stats: bool,
  1163. ):
  1164. """Create a SpeculativeConfig object.
  1165. Args:
  1166. draft_model_config: ModelConfig for the draft model.
  1167. draft_parallel_config: ParallelConfig for the draft model.
  1168. num_speculative_tokens: The number of tokens to sample from the
  1169. draft model before scoring with the target model.
  1170. speculative_disable_by_batch_size: Disable speculative
  1171. decoding for new incoming requests when the number of
  1172. enqueue requests is larger than this value.
  1173. ngram_prompt_lookup_max: Max size of ngram token window.
  1174. ngram_prompt_lookup_min: Min size of ngram token window.
  1175. draft_token_acceptance_method (str): The method to use for
  1176. accepting draft tokens. This can take two possible
  1177. values 'rejection_sampler' and 'typical_acceptance_sampler'
  1178. for RejectionSampler and TypicalAcceptanceSampler
  1179. respectively.
  1180. typical_acceptance_sampler_posterior_threshold (Optional[float]):
  1181. A threshold value that sets a lower bound on the posterior
  1182. probability of a token in the target model for it to be
  1183. accepted. This threshold is used only when we use the
  1184. TypicalAcceptanceSampler for token acceptance.
  1185. typical_acceptance_sampler_posterior_alpha (Optional[float]):
  1186. A scaling factor for the entropy-based threshold in the
  1187. TypicalAcceptanceSampler.
  1188. disable_logprobs: If set to True, token log probabilities will not
  1189. be returned even if requested by sampling parameters. This
  1190. reduces latency by skipping logprob calculation in proposal
  1191. sampling, target sampling, and after accepted tokens are
  1192. determined. If set to False, log probabilities will be
  1193. returned.
  1194. disable_log_stats: Whether to disable periodic printing of stage
  1195. times in speculative decoding.
  1196. """
  1197. self.draft_model_config = draft_model_config
  1198. self.draft_parallel_config = draft_parallel_config
  1199. self.num_speculative_tokens = num_speculative_tokens
  1200. self.speculative_disable_by_batch_size = \
  1201. speculative_disable_by_batch_size
  1202. self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
  1203. self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
  1204. self.draft_token_acceptance_method = draft_token_acceptance_method
  1205. self.typical_acceptance_sampler_posterior_threshold = \
  1206. typical_acceptance_sampler_posterior_threshold
  1207. self.typical_acceptance_sampler_posterior_alpha = \
  1208. typical_acceptance_sampler_posterior_alpha
  1209. self.disable_logprobs = disable_logprobs
  1210. self.disable_log_stats = disable_log_stats
  1211. self._verify_args()
  1212. def _verify_args(self) -> None:
  1213. if self.num_speculative_tokens <= 0:
  1214. raise ValueError("Expected num_speculative_tokens to be greater "
  1215. f"than zero ({self.num_speculative_tokens}).")
  1216. if self.draft_model_config:
  1217. self.draft_model_config.verify_with_parallel_config(
  1218. self.draft_parallel_config)
  1219. # Validate and set draft token acceptance related settings.
  1220. if (self.draft_token_acceptance_method is None):
  1221. raise ValueError("draft_token_acceptance_method is not set. "
  1222. "Expected values are rejection_sampler or "
  1223. "typical_acceptance_sampler.")
  1224. if (self.draft_token_acceptance_method != 'rejection_sampler'
  1225. and self.draft_token_acceptance_method !=
  1226. 'typical_acceptance_sampler'):
  1227. raise ValueError(
  1228. "Expected draft_token_acceptance_method to be either "
  1229. "rejection_sampler or typical_acceptance_sampler. Instead it "
  1230. f"is {self.draft_token_acceptance_method}")
  1231. if (self.typical_acceptance_sampler_posterior_threshold < 0
  1232. or self.typical_acceptance_sampler_posterior_alpha < 0):
  1233. raise ValueError(
  1234. "Expected typical_acceptance_sampler_posterior_threshold "
  1235. "and typical_acceptance_sampler_posterior_alpha to be > 0. "
  1236. "Instead found "
  1237. f"typical_acceptance_sampler_posterior_threshold = "
  1238. f"{self.typical_acceptance_sampler_posterior_threshold} and "
  1239. f"typical_acceptance_sampler_posterior_alpha = "
  1240. f"{self.typical_acceptance_sampler_posterior_alpha}")
  1241. @property
  1242. def num_lookahead_slots(self) -> int:
  1243. """The number of additional slots the scheduler should allocate per
  1244. step, in addition to the slots allocated for each known token.
  1245. This is equal to the number of speculative tokens, as each speculative
  1246. token must be scored.
  1247. """
  1248. return self.num_speculative_tokens
  1249. def __repr__(self) -> str:
  1250. if self.ngram_prompt_lookup_max > 0:
  1251. draft_model = "[ngram]"
  1252. else:
  1253. draft_model = self.draft_model_config.model
  1254. num_spec_tokens = self.num_speculative_tokens
  1255. return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
  1256. @dataclass
  1257. class LoRAConfig:
  1258. max_lora_rank: int
  1259. max_loras: int
  1260. fully_sharded_loras: bool = False
  1261. max_cpu_loras: Optional[int] = None
  1262. lora_dtype: Optional[torch.dtype] = None
  1263. lora_extra_vocab_size: int = 256
  1264. # This is a constant.
  1265. lora_vocab_padding_size: ClassVar[int] = 256
  1266. long_lora_scaling_factors: Optional[Tuple[float]] = None
  1267. def __post_init__(self):
  1268. # Setting the maximum rank to 256 should be able to satisfy the vast
  1269. # majority of applications.
  1270. possible_max_ranks = (8, 16, 32, 64, 128, 256)
  1271. possible_lora_extra_vocab_size = (0, 256, 512)
  1272. if self.max_lora_rank not in possible_max_ranks:
  1273. raise ValueError(
  1274. f"max_lora_rank ({self.max_lora_rank}) must be one of "
  1275. f"{possible_max_ranks}.")
  1276. if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
  1277. raise ValueError(
  1278. f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
  1279. f"must be one of {possible_lora_extra_vocab_size}.")
  1280. if self.max_loras < 1:
  1281. raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
  1282. if self.max_cpu_loras is None:
  1283. self.max_cpu_loras = self.max_loras
  1284. elif self.max_cpu_loras < self.max_loras:
  1285. raise ValueError(
  1286. f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
  1287. f"max_loras ({self.max_loras})")
  1288. def verify_with_model_config(self, model_config: ModelConfig):
  1289. if self.lora_dtype in (None, "auto"):
  1290. self.lora_dtype = model_config.dtype
  1291. elif isinstance(self.lora_dtype, str):
  1292. self.lora_dtype = getattr(torch, self.lora_dtype)
  1293. if model_config.quantization and model_config.quantization not in [
  1294. "awq", "gptq"
  1295. ]:
  1296. # TODO support all other quants
  1297. logger.warning(f"{model_config.quantization} quantization is not "
  1298. "tested with LoRA yet.")
  1299. def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
  1300. if scheduler_config.chunked_prefill_enabled:
  1301. raise ValueError("LoRA is not supported with chunked prefill yet.")
  1302. def verify_with_parallel_config(self, parallel_config: ParallelConfig):
  1303. if self.lora_vocab_padding_size % parallel_config.world_size != 0:
  1304. raise ValueError("LoRA vocab padding size must be divisible "
  1305. "by world size.")
  1306. @dataclass
  1307. class PromptAdapterConfig:
  1308. max_prompt_adapters: int
  1309. max_prompt_adapter_token: int
  1310. max_cpu_prompt_adapters: Optional[int] = None
  1311. prompt_adapter_dtype: Optional[torch.dtype] = None
  1312. def __post_init__(self):
  1313. library_name = 'peft'
  1314. try:
  1315. __import__(library_name)
  1316. except ImportError as e:
  1317. raise ImportError(
  1318. f"'{library_name}' is not installed for prompt adapter support."
  1319. f"Please install it using 'pip install {library_name}'."
  1320. ) from e
  1321. if self.max_prompt_adapters < 1:
  1322. raise ValueError(f"max_prompt_adapters "
  1323. f"({self.max_prompt_adapters}) must be >= 1.")
  1324. if self.max_prompt_adapter_token == 0:
  1325. raise ValueError("max_prompt_adapter_token must be set.")
  1326. if self.max_cpu_prompt_adapters is None:
  1327. self.max_cpu_prompt_adapters = self.max_prompt_adapters
  1328. def verify_with_model_config(self, model_config: ModelConfig):
  1329. if self.prompt_adapter_dtype in (None, "auto"):
  1330. self.prompt_adapter_dtype = model_config.dtype
  1331. elif isinstance(self.prompt_adapter_dtype, str):
  1332. self.prompt_adapter_dtype = getattr(torch,
  1333. self.prompt_adapter_dtype)
  1334. @dataclass
  1335. class MultiModalConfig:
  1336. """Configs the input data format and how models should run for
  1337. multimodal models."""
  1338. # TODO: Add configs to init vision tower or not.
  1339. pass
  1340. _STR_DTYPE_TO_TORCH_DTYPE = {
  1341. "half": torch.float16,
  1342. "float16": torch.float16,
  1343. "float": torch.float32,
  1344. "float32": torch.float32,
  1345. "bfloat16": torch.bfloat16,
  1346. }
  1347. _ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
  1348. def _get_and_verify_dtype(
  1349. config: PretrainedConfig,
  1350. dtype: Union[str, torch.dtype],
  1351. ) -> torch.dtype:
  1352. # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
  1353. # because config.torch_dtype can be None.
  1354. config_dtype = getattr(config, "torch_dtype", None)
  1355. if config_dtype is None:
  1356. config_dtype = torch.float32
  1357. if isinstance(dtype, str):
  1358. dtype = dtype.lower()
  1359. if dtype == "auto":
  1360. if config_dtype == torch.float32:
  1361. if config.model_type == "gemma2":
  1362. logger.info(
  1363. "For Gemma 2, we downcast float32 to bfloat16 instead "
  1364. "of float16 by default. Please specify `dtype` if you "
  1365. "want to use float16.")
  1366. torch_dtype = torch.bfloat16
  1367. else:
  1368. # Following the common practice, we use float16 for float32
  1369. # models.
  1370. torch_dtype = torch.float16
  1371. else:
  1372. torch_dtype = config_dtype
  1373. else:
  1374. if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
  1375. raise ValueError(f"Unknown dtype: {dtype}")
  1376. torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
  1377. elif isinstance(dtype, torch.dtype):
  1378. torch_dtype = dtype
  1379. else:
  1380. raise ValueError(f"Unknown dtype: {dtype}")
  1381. if is_hip() and torch_dtype == torch.float32:
  1382. rocm_supported_dtypes = [
  1383. k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
  1384. if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
  1385. ]
  1386. raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
  1387. f"Supported dtypes are {rocm_supported_dtypes}")
  1388. # Verify the dtype.
  1389. if torch_dtype != config_dtype:
  1390. if torch_dtype == torch.float32:
  1391. # Upcasting to float32 is allowed.
  1392. pass
  1393. elif config_dtype == torch.float32:
  1394. # Downcasting from float32 to float16 or bfloat16 is allowed.
  1395. pass
  1396. else:
  1397. # Casting between float16 and bfloat16 is allowed with a warning.
  1398. logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
  1399. return torch_dtype
  1400. def _get_and_verify_max_len(
  1401. hf_config: PretrainedConfig,
  1402. max_model_len: Optional[int],
  1403. disable_sliding_window: bool,
  1404. sliding_window_len: Optional[int],
  1405. rope_scaling_arg: Optional[Dict[str, Any]],
  1406. ) -> int:
  1407. """Get and verify the model's maximum length."""
  1408. derived_max_model_len = float("inf")
  1409. possible_keys = [
  1410. # Cohere: needs to prioritize this over "max_position_embeddings"
  1411. "model_max_length",
  1412. # OPT
  1413. "max_position_embeddings",
  1414. # GPT-2
  1415. "n_positions",
  1416. # MPT
  1417. "max_seq_len",
  1418. # ChatGLM2
  1419. "seq_length",
  1420. # Command-R
  1421. "model_max_length",
  1422. # Others
  1423. "max_sequence_length",
  1424. "max_seq_length",
  1425. "seq_len",
  1426. ]
  1427. # Choose the smallest "max_length" from the possible keys.
  1428. max_len_key = None
  1429. for key in possible_keys:
  1430. max_len = getattr(hf_config, key, None)
  1431. if max_len is not None:
  1432. max_len_key = key if max_len < derived_max_model_len \
  1433. else max_len_key
  1434. derived_max_model_len = min(derived_max_model_len, max_len)
  1435. # If sliding window is manually disabled, max_length should be less
  1436. # than the sliding window length in the model config.
  1437. if disable_sliding_window and sliding_window_len is not None:
  1438. max_len_key = "sliding_window" \
  1439. if sliding_window_len < derived_max_model_len else max_len_key
  1440. derived_max_model_len = min(derived_max_model_len, sliding_window_len)
  1441. # If none of the keys were found in the config, use a default and
  1442. # log a warning.
  1443. if derived_max_model_len == float("inf"):
  1444. if max_model_len is not None:
  1445. # If max_model_len is specified, we use it.
  1446. return max_model_len
  1447. default_max_len = 2048
  1448. logger.warning(
  1449. "The model's config.json does not contain any of the following "
  1450. "keys to determine the original maximum length of the model: "
  1451. f"{possible_keys}. Assuming the model's maximum length is "
  1452. f"{default_max_len}.")
  1453. derived_max_model_len = default_max_len
  1454. rope_scaling = getattr(hf_config, "rope_scaling", None)
  1455. if rope_scaling is not None:
  1456. rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
  1457. if rope_type not in {"su", "longrope", "llama3"}:
  1458. if disable_sliding_window:
  1459. # TODO: Find a model that supports rope_scaling
  1460. # with sliding window to see if this case should be allowed.
  1461. raise NotImplementedError(
  1462. "Disabling sliding window is not supported for models "
  1463. "with rope_scaling. Please raise an issue so we can "
  1464. "investigate.")
  1465. assert "factor" in rope_scaling
  1466. scaling_factor = rope_scaling["factor"]
  1467. if rope_type == "yarn":
  1468. derived_max_model_len = rope_scaling[
  1469. "original_max_position_embeddings"]
  1470. derived_max_model_len *= scaling_factor
  1471. if max_model_len is None:
  1472. max_model_len = derived_max_model_len
  1473. elif max_model_len > derived_max_model_len and rope_scaling_arg is None:
  1474. raise ValueError(
  1475. f"User-specified max_model_len {max_model_len} is higher than "
  1476. f"the original {derived_max_model_len}. "
  1477. "Please provide a rope_scaling dict to scale the model.")
  1478. elif max_model_len > derived_max_model_len and rope_scaling_arg is not None:
  1479. # hope this works
  1480. logger.warning(
  1481. f"User-specified max_model_len {max_model_len} is higher than "
  1482. f"the original {derived_max_model_len}. "
  1483. "Attempting to use RoPE scaling with the provided rope_scaling "
  1484. "dict.")
  1485. derived_max_model_len = max_model_len
  1486. return int(max_model_len)
  1487. def get_served_model_name(model: str,
  1488. served_model_name: Optional[Union[str, List[str]]]):
  1489. """
  1490. If the input is a non-empty list, the first model_name in
  1491. `served_model_name` is taken.
  1492. If the input is a non-empty string, it is used directly.
  1493. For cases where the input is either an empty string or an
  1494. empty list, the fallback is to use `self.model`.
  1495. """
  1496. if not served_model_name:
  1497. return model
  1498. if isinstance(served_model_name, list):
  1499. return served_model_name[0]
  1500. return served_model_name
  1501. @dataclass
  1502. class DecodingConfig:
  1503. """Dataclass which contains the decoding strategy of the engine"""
  1504. # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
  1505. guided_decoding_backend: str = 'outlines'
  1506. def __post_init__(self):
  1507. valid_guided_backends = ['outlines', 'lm-format-enforcer']
  1508. backend = self.guided_decoding_backend
  1509. if backend not in valid_guided_backends:
  1510. raise ValueError(f"Invalid guided_decoding_backend '{backend},"
  1511. f"must be one of {valid_guided_backends}")
  1512. @dataclass(frozen=True)
  1513. class EngineConfig:
  1514. """Dataclass which contains all engine-related configuration. This
  1515. simplifies passing around the distinct configurations in the codebase.
  1516. """
  1517. model_config: ModelConfig
  1518. cache_config: CacheConfig
  1519. parallel_config: ParallelConfig
  1520. scheduler_config: SchedulerConfig
  1521. device_config: DeviceConfig
  1522. load_config: LoadConfig
  1523. lora_config: Optional[LoRAConfig]
  1524. multimodal_config: Optional[MultiModalConfig]
  1525. speculative_config: Optional[SpeculativeConfig]
  1526. decoding_config: Optional[DecodingConfig]
  1527. prompt_adapter_config: Optional[PromptAdapterConfig]
  1528. def __post_init__(self):
  1529. """Verify configs are valid & consistent with each other.
  1530. """
  1531. self.model_config.verify_with_parallel_config(self.parallel_config)
  1532. self.cache_config.verify_with_parallel_config(self.parallel_config)
  1533. if self.lora_config:
  1534. self.lora_config.verify_with_model_config(self.model_config)
  1535. self.lora_config.verify_with_scheduler_config(
  1536. self.scheduler_config)
  1537. self.lora_config.verify_with_parallel_config(self.parallel_config)
  1538. if self.prompt_adapter_config:
  1539. self.prompt_adapter_config.verify_with_model_config(
  1540. self.model_config)
  1541. def to_dict(self):
  1542. """Return the configs as a dictionary, for use in **kwargs.
  1543. """
  1544. return dict(
  1545. (field.name, getattr(self, field.name)) for field in fields(self))