config.py 89 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018
  1. import enum
  2. import json
  3. import os
  4. from dataclasses import dataclass, field, fields
  5. from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
  6. Optional, Tuple, Type, Union)
  7. import torch
  8. from loguru import logger
  9. from transformers import PretrainedConfig
  10. import aphrodite.common.envs as envs
  11. from aphrodite.common.utils import (GiB_bytes, cuda_device_count_stateless,
  12. get_cpu_memory, is_cpu, is_hip, is_neuron,
  13. is_openvino, is_xpu, print_warning_once)
  14. from aphrodite.distributed import get_current_tp_rank_partition_size
  15. from aphrodite.modeling.models import ModelRegistry
  16. from aphrodite.platforms import current_platform
  17. from aphrodite.quantization import QUANTIZATION_METHODS
  18. from aphrodite.transformers_utils.config import (ConfigFormat, get_config,
  19. get_hf_image_processor_config,
  20. get_hf_text_config)
  21. from aphrodite.triton_utils import HAS_TRITON
  22. if TYPE_CHECKING:
  23. from ray.util.placement_group import PlacementGroup
  24. from aphrodite.executor.executor_base import ExecutorBase
  25. from aphrodite.modeling.model_loader.loader import BaseModelLoader
  26. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import ( # noqa: E501
  27. BaseTokenizerGroup)
# If true, will load models from ModelScope instead of Hugging Face Hub.
APHRODITE_USE_MODELSCOPE = envs.APHRODITE_USE_MODELSCOPE

# Token-count limits for embedding and multimodal models. They are not
# referenced in this part of the file; presumably they bound the number of
# batched tokens in the scheduler configuration -- TODO confirm.
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096

# Architecture names for which pipeline parallelism is allowed.
# `ModelConfig.verify_with_parallel_config` raises NotImplementedError when
# pipeline_parallel_size > 1 and any model architecture is not in this list.
_PP_SUPPORTED_MODELS = [
    "AquilaModel",
    "AquilaForCausalLM",
    "InternLMForCausalLM",
    "LlamaForCausalLM",
    "LLaMAForCausalLM",
    "MistralForCausalLM",
    "Phi3ForCausalLM",
    "MixtralForCausalLM",
    "NemotronForCausalLM",
    "Qwen2ForCausalLM",
    "Qwen2MoeForCausalLM",
    "InternLM2ForCausalLM",
    "InternVLChatModel",
]

# Quantization methods considered optimized. `ModelConfig._verify_quantization`
# logs a "not fully optimized yet" warning for any method not listed here.
_OPTIMIZED_QUANTS = [
    "awq_marlin",
    "compressed-tensors",
    "compressed_tensors",
    "experts_int8",
    "fbgemm_fp8",
    "fp2",
    "fp3",
    "fp4",
    "fp5",
    "fp6",
    "fp7",
    "fp8",
    "gptq_marlin",
    "gptq_marlin_24",
    "marlin",
    "modelopt",
    "quant_llm",
]
class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
            It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, "slow" will always use the slow tokenizer, and
            "mistral" will always use the tokenizer from `mistral_common`.
            The value is matched case-insensitively.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        rope_scaling: Dictionary containing the scaling configuration for the
            RoPE embeddings. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        rope_theta: RoPE theta value forwarded to the HF config loader
            together with `rope_scaling` (presumably an override of the
            model's own value -- TODO confirm).
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, the
            model `revision` is used.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        spec_target_max_model_len: Forwarded to the max-length derivation
            (presumably the target model's maximum length when speculative
            decoding is used -- TODO confirm).
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
            The value is matched case-insensitively.
        deepspeed_fp_bits: Number of bits to use for DeepSpeed FP quantization.
            Supported number of bits are: 4, 6, 8, 12. Required when
            `quantization` is "deepspeedfp".
        quant_llm_fp_bits: Number of bits to use for QuantLLM FP quantization.
            Must be one of 2, 3, 4, 5, 6, 7. Required when `quantization` is
            "quant_llm".
        quant_llm_exp_bits: Number of exponent bits for QuantLLM FP
            quantization. Must be one of 1, 2, 3, 4, 5; if None, a default is
            chosen from `quant_llm_fp_bits`.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None, the user did not specify, so default to False.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use max_seq_len_to_capture instead;
            passing any non-None value raises ValueError).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode. The value is capped at `max_model_len`.
        max_logprobs: Stored as-is on the config (presumably the maximum
            number of log probabilities that may be requested -- TODO
            confirm).
        disable_sliding_window: Whether to disable sliding window. If True,
            we will disable the sliding window functionality of the model.
            If the model does not support sliding window, this argument is
            ignored.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer.
        served_model_name: The model name used in metrics tag `model_name`,
            matches the model name exposed via the APIs. If multiple model
            names provided, the first name will be used. If not specified,
            the model name will be the same as `model`.
        limit_mm_per_prompt: Maximum number of data instances per modality
            per prompt. Only applicable for multimodal models.
        use_async_output_proc: Whether to use asynchronous output processing.
            It may be switched off later by `verify_async_output_proc` or
            `verify_with_parallel_config`.
        config_format: The config format which will be loaded. Defaults to
            'auto' which defaults to 'hf'.
        override_neuron_config: Initialize non default neuron config or
            override default neuron config that are specific to Neuron devices,
            this argument will be used to configure the neuron config that
            can not be gathered from the Aphrodite arguments.
    """
    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        rope_scaling: Optional[dict] = None,
        rope_theta: Optional[float] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        spec_target_max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        deepspeed_fp_bits: Optional[int] = None,
        quant_llm_fp_bits: Optional[int] = None,
        quant_llm_exp_bits: Optional[int] = None,
        quantization_param_path: Optional[str] = None,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
        use_async_output_proc: bool = True,
        config_format: ConfigFormat = ConfigFormat.AUTO,
        override_neuron_config: Optional[Dict[str, Any]] = None
    ) -> None:
        """Store the arguments, load the HF config and validate the result.

        The HF config is loaded first; dtype, maximum model length, served
        model name and multimodal config are derived from it, and finally
        the tokenizer mode, embedding mode, quantization and CUDA graph
        settings are verified.

        Raises:
            ValueError: If `max_context_len_to_capture` is given
                (deprecated), or if one of the verification steps rejects
                the configuration.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
            self.tokenizer_revision = revision
        else:
            self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.deepspeed_fp_bits = deepspeed_fp_bits
        self.quant_llm_fp_bits = quant_llm_fp_bits
        self.quant_llm_exp_bits = quant_llm_exp_bits
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        # The deprecated argument is rejected outright rather than honoured.
        if self.max_context_len_to_capture is not None:
            raise ValueError("`max_context_len_to_capture` is deprecated. "
                             "Use `max_seq_len_to_capture` instead.")
        # `max_context_len_to_capture` is always None here (see the check
        # above), so this is effectively `max_seq_len_to_capture` itself.
        self.max_seq_len_to_capture = (max_seq_len_to_capture
                                       or max_context_len_to_capture)
        self.max_logprobs = max_logprobs
        self.disable_sliding_window = disable_sliding_window
        self.skip_tokenizer_init = skip_tokenizer_init
        # Load the HF config; everything below is derived from it.
        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision, rope_scaling, rope_theta,
                                    config_format)
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
            self.enforce_eager = False
        # A model is treated as having interleaved attention when its
        # sliding window is a list, or when it is a gemma2 model with a
        # sliding window set.
        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
        has_interleaved_attention = (sliding_window is not None) and (
            isinstance(sliding_window, list) or
            (self.hf_text_config.model_type in ["gemma2"]))
        if (not self.disable_sliding_window and has_interleaved_attention):
            sliding_window_len_min = get_min_sliding_window(
                self.hf_text_config.sliding_window)
            print_warning_once(
                f"{self.hf_text_config.model_type} has interleaved attention, "
                "which is currently not supported by vLLM. Disabling sliding "
                "window and capping the max length to the sliding window size "
                f"({sliding_window_len_min}).")
            self.disable_sliding_window = True
        # Must run after the interleaved-attention check above, because it
        # reads the possibly updated `disable_sliding_window`.
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=self.disable_sliding_window,
            sliding_window_len=self.get_hf_config_sliding_window(),
            spec_target_max_model_len=spec_target_max_model_len,
            rope_scaling_arg=self.rope_scaling)
        self.served_model_name = get_served_model_name(model,
                                                       served_model_name)
        self.multimodal_config = self._init_multimodal_config(
            limit_mm_per_prompt)
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()
        # The neuron override is only kept when running on Neuron.
        self.override_neuron_config = override_neuron_config if is_neuron(
        ) else None
        self._verify_embedding_mode()
        # May also change `self.dtype` and `self.enforce_eager` for the
        # online "fpN" quantization methods.
        self._verify_quantization()
        # Reads `self.max_model_len`, so it runs after the length is known.
        self._verify_cuda_graph()
  241. def _init_multimodal_config(
  242. self, limit_mm_per_prompt: Optional[Mapping[str, int]]
  243. ) -> Optional["MultiModalConfig"]:
  244. architectures = getattr(self.hf_config, "architectures", [])
  245. if any(
  246. ModelRegistry.is_multimodal_model(arch)
  247. for arch in architectures):
  248. return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
  249. else:
  250. if limit_mm_per_prompt:
  251. raise ValueError(
  252. "limit_mm_per_prompt is only supported for multimodal "
  253. "models.")
  254. return None
  255. def _verify_tokenizer_mode(self) -> None:
  256. tokenizer_mode = self.tokenizer_mode.lower()
  257. if tokenizer_mode not in ["auto", "slow", "mistral"]:
  258. raise ValueError(
  259. f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
  260. "either 'auto', 'slow' or 'mistral'.")
  261. self.tokenizer_mode = tokenizer_mode
  262. def _verify_embedding_mode(self) -> None:
  263. architectures = getattr(self.hf_config, "architectures", [])
  264. self.embedding_mode = any(
  265. ModelRegistry.is_embedding_model(arch) for arch in architectures)
  266. def _parse_quant_hf_config(self):
  267. quant_cfg = getattr(self.hf_config, "quantization_config", None)
  268. if quant_cfg is None:
  269. # compress-tensors uses a "compression_config" key
  270. quant_cfg = getattr(self.hf_config, "compression_config", None)
  271. return quant_cfg
    def _verify_quantization(self) -> None:
        """Resolve and validate the quantization method.

        Lower-cases `self.quantization`, reconciles it with the quantization
        section of the HF config (if present), fills in
        `hf_config.quantization_config` for the "deepspeedfp", "quant_llm"
        and online "fp2".."fp7" methods, and finally checks that the
        resolved method is known and supported on the current platform.

        Side effects: may set `self.quantization`, `self.quant_llm_exp_bits`,
        `self.hf_config.quantization_config`, `self.dtype`,
        `self.enforce_eager` and `envs.APHRODITE_USE_TRITON_AWQ`.

        Raises:
            ValueError: On a mismatch between the argument and the model
                config, missing or invalid bit settings, an unknown method,
                or a method unsupported on ROCm / TPU / Neuron.
        """
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq", "gptq", "squeezellm", "fp8",
            "compressed_tensors", "compressed-tensors",
            "fbgemm_fp8"
        ]
        tpu_supported_quantization = ["tpu_int8"]
        neuron_supported_quantization = ["neuron_quant"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()
        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()
        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()
            # Detect which checkpoint is it
            # The first method that reports an override wins.
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization)
                if quantization_override:
                    if quantization_override == "awq_marlin":
                        # The override is ignored: the self-assignment is a
                        # no-op, so the checkpoint's own method is kept.
                        # NOTE(review): nesting of this if/else was
                        # reconstructed from unindented source -- confirm.
                        quant_method = quant_method
                        logger.warning(
                            "awq_marlin kernels are temporarily disabled, "
                            "they will be re-enabled with a future release. "
                            "Falling back to AWQ kernels.")
                    else:
                        quant_method = quantization_override
                        self.quantization = quantization_override
                    break
            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")
        if self.quantization == "deepspeedfp":
            # Default group size is 32 for 4-bit and 128 otherwise; the
            # DEEPSPEED_GROUP_SIZE environment variable overrides it.
            # `deepspeed_fp_bits` is only checked for None further below.
            gs = 32 if self.deepspeed_fp_bits == 4 else 128
            self.hf_config.quantization_config = {
                "bits": self.deepspeed_fp_bits,
                "group_size": int(os.environ.get("DEEPSPEED_GROUP_SIZE", gs)),
                "quant_method": "deepspeedfp"
            }
        VALID_QUANT_LLM_FP_BITS = [2, 3, 4, 5, 6, 7]
        VALID_QUANT_LLM_EXPONENTS = [1, 2, 3, 4, 5]
        # The formula is mantissa_bits = fp_bits - exp_bits - 1
        # The default exp_bits for each fp_bits are as follows:
        DEFAULT_EXP_BITS = {
            2: 1,
            3: 2,
            4: 2,
            5: 2,
            6: 2,
            7: 3,
        }
        if self.quantization == "quant_llm":
            if self.quant_llm_fp_bits is None:
                raise ValueError(
                    "quant_llm_fp_bits must be specified when using "
                    "quant_llm quantization."
                )
            if self.quant_llm_fp_bits not in VALID_QUANT_LLM_FP_BITS:
                raise ValueError(
                    f"Invalid quant_llm_fp_bits: {self.quant_llm_fp_bits}. "
                    f"Must be one of {VALID_QUANT_LLM_FP_BITS}."
                )
            if self.quant_llm_exp_bits is None:
                self.quant_llm_exp_bits = DEFAULT_EXP_BITS[
                    self.quant_llm_fp_bits]
            else:
                if self.quant_llm_exp_bits not in VALID_QUANT_LLM_EXPONENTS:
                    raise ValueError(
                        f"Invalid exponent bits: {self.quant_llm_exp_bits}. "
                        f"Must be one of {VALID_QUANT_LLM_EXPONENTS}."
                    )
            self.hf_config.quantization_config = {
                "bits": self.quant_llm_fp_bits,
                "exp_bits": self.quant_llm_exp_bits,
                "quant_method": "quant_llm"
            }
        online_quant_methods = ["fp2", "fp3", "fp4", "fp5", "fp6", "fp7"]
        if self.quantization is not None and self.quantization in \
            online_quant_methods:
            # The bit width is the digit after the "fp" prefix.
            fp_bits = int(self.quantization[2])
            if fp_bits not in VALID_QUANT_LLM_FP_BITS:
                raise ValueError(
                    f"Invalid quant_llm_fp_bits: {fp_bits}. "
                    f"Must be one of {VALID_QUANT_LLM_FP_BITS}."
                )
            if fp_bits in [2, 3]:
                logger.warning("FP2 and FP3 quantization methods lead to "
                               "significant accuracy loss. Use them with "
                               "caution. Model may be incoherent.")
            exp_bits = DEFAULT_EXP_BITS[fp_bits]
            self.hf_config.quantization_config = {
                "bits": fp_bits,
                "exp_bits": exp_bits,
                "quant_method": self.quantization
            }
            # These methods force float16 and eager execution.
            self.dtype = torch.float16
            self.enforce_eager = True
        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    "supported in ROCm.")
            if current_platform.is_tpu(
            ) and self.quantization not in tpu_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in TPU Backend.")
            if self.quantization not in _OPTIMIZED_QUANTS:
                logger.warning(
                    f"{self.quantization} quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.")
            if self.quantization == "deepspeedfp" and self.deepspeed_fp_bits \
                is None:
                raise ValueError(
                    "deepspeed_fp_bits must be specified when using "
                    "deepspeedfp quantization.")
            if (self.quantization == "awq" and is_hip()
                    and not envs.APHRODITE_USE_TRITON_AWQ):
                logger.warning(
                    "Using AWQ quantization with ROCm, but "
                    "APHRODITE_USE_TRITON_AWQ is not set, enabling "
                    "APHRODITE_USE_TRITON_AWQ.")
                envs.APHRODITE_USE_TRITON_AWQ = True
            if is_neuron(
            ) and self.quantization not in neuron_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in Neuron Backend.")
  413. def _verify_cuda_graph(self) -> None:
  414. if self.max_seq_len_to_capture is None:
  415. self.max_seq_len_to_capture = self.max_model_len
  416. self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
  417. self.max_model_len)
  418. def verify_async_output_proc(self, parallel_config, speculative_config,
  419. device_config) -> None:
  420. if not self.use_async_output_proc:
  421. # Nothing to check
  422. return
  423. if parallel_config.pipeline_parallel_size > 1:
  424. logger.warning("Async output processing can not be enabled "
  425. "with pipeline parallel")
  426. self.use_async_output_proc = False
  427. return
  428. if device_config.device_type not in ("cuda", "tpu"):
  429. logger.warning(
  430. "Async output processing is only supported for CUDA or TPU. "
  431. "Disabling it for other platforms.")
  432. self.use_async_output_proc = False
  433. return
  434. if envs.APHRODITE_USE_RAY_SPMD_WORKER:
  435. logger.warning(
  436. "Async output processing can not be enabled with ray spmd")
  437. self.use_async_output_proc = False
  438. return
  439. if device_config.device_type == "cuda" and self.enforce_eager:
  440. logger.warning(
  441. "To see benefits of async output processing, enable CUDA "
  442. "graph. Since, enforce-eager is enabled, async output "
  443. "processor cannot be used")
  444. self.use_async_output_proc = not self.enforce_eager
  445. return
  446. # Async postprocessor is not necessary with embedding mode
  447. # since there is no token generation
  448. if self.embedding_mode:
  449. self.use_async_output_proc = False
  450. if speculative_config:
  451. logger.warning("Async output processing is not supported with"
  452. " speculative decoding currently.")
  453. self.use_async_output_proc = False
  454. def verify_with_parallel_config(
  455. self,
  456. parallel_config: "ParallelConfig",
  457. ) -> None:
  458. total_num_attention_heads = getattr(self.hf_text_config,
  459. "num_attention_heads", 0)
  460. tensor_parallel_size = parallel_config.tensor_parallel_size
  461. if (total_num_attention_heads % tensor_parallel_size != 0
  462. and self.quantization is not None):
  463. raise ValueError(
  464. f"Total number of attention heads "
  465. f"({total_num_attention_heads})"
  466. " must be divisible by tensor parallel size "
  467. f"({tensor_parallel_size}) when quantization is used.")
  468. pipeline_parallel_size = parallel_config.pipeline_parallel_size
  469. architectures = getattr(self.hf_config, "architectures", [])
  470. if not all(arch in _PP_SUPPORTED_MODELS
  471. for arch in architectures) and pipeline_parallel_size > 1:
  472. raise NotImplementedError(
  473. "Pipeline parallelism is only supported for the following "
  474. f" architectures: {_PP_SUPPORTED_MODELS}.")
  475. if self.quantization == "bitsandbytes" and self.enforce_eager is False:
  476. logger.warning("CUDA graph is not supported on BitAndBytes yet, "
  477. "fallback to the eager mode.")
  478. self.enforce_eager = True
  479. if pipeline_parallel_size > 1 and self.use_async_output_proc:
  480. logger.warning("Async output processor is not supported with "
  481. "pipeline parallelism currently. Disabling it.")
  482. self.use_async_output_proc = False
  483. def is_attention_free(self) -> bool:
  484. """Returns True if the model has no attention, i.e. the model has no
  485. state that grows with the size of the context.
  486. """
  487. # Return true if the model is mamba.
  488. # This check should be augmented with more models in the future,
  489. # and made more robust if possible.
  490. if hasattr(self.hf_text_config,
  491. "model_type") and self.hf_text_config.model_type == 'mamba':
  492. return True
  493. return False
  494. def get_hf_config_sliding_window(
  495. self) -> Union[Optional[int], List[Optional[int]]]:
  496. """Get the sliding window size, or None if disabled.
  497. """
  498. # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
  499. # addition to sliding window size. We check if that field is present
  500. # and if it's False, return None.
  501. if (hasattr(self.hf_text_config, "use_sliding_window")
  502. and not self.hf_text_config.use_sliding_window):
  503. return None
  504. return getattr(self.hf_text_config, "sliding_window", None)
  505. def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]:
  506. """Get the sliding window size, or None if disabled.
  507. """
  508. # If user disables sliding window, return None.
  509. if self.disable_sliding_window:
  510. return None
  511. # Otherwise get the value from the hf config.
  512. return self.get_hf_config_sliding_window()
  513. def get_vocab_size(self) -> int:
  514. return self.hf_text_config.vocab_size
  515. def get_hidden_size(self) -> int:
  516. return self.hf_text_config.hidden_size
  517. def get_head_size(self) -> int:
  518. # TODO remove hard code
  519. spec_model_types = ["medusa", "mlp_speculator"]
  520. if hasattr(self.hf_text_config, "model_type"
  521. ) and self.hf_text_config.model_type == 'deepseek_v2':
  522. # FlashAttention supports only head_size 32, 64, 128, 256,
  523. # we need to pad head_size 192 to 256
  524. return 256
  525. if self.is_attention_free() or \
  526. self.hf_text_config.model_type in spec_model_types:
  527. return 0
  528. if hasattr(self.hf_text_config, "head_dim"):
  529. return self.hf_text_config.head_dim
  530. # FIXME: This may not be true for all models.
  531. return (self.hf_text_config.hidden_size //
  532. self.hf_text_config.num_attention_heads)
  533. def get_total_num_kv_heads(self) -> int:
  534. """Returns the total number of KV heads."""
  535. # For GPTBigCode & Falcon:
  536. # NOTE: for falcon, when new_decoder_architecture is True, the
  537. # multi_query flag is ignored and we use n_head_kv for the number of
  538. # KV heads.
  539. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
  540. new_decoder_arch_falcon = (
  541. self.hf_config.model_type in falcon_model_types
  542. and getattr(self.hf_config, "new_decoder_architecture", False))
  543. if not new_decoder_arch_falcon and getattr(self.hf_text_config,
  544. "multi_query", False):
  545. # Multi-query attention, only one KV head.
  546. # Currently, tensor parallelism is not supported in this case.
  547. return 1
  548. # For DBRX and MPT
  549. if self.hf_config.model_type == "mpt":
  550. if "kv_n_heads" in self.hf_config.attn_config:
  551. return self.hf_config.attn_config["kv_n_heads"]
  552. return self.hf_config.num_attention_heads
  553. if self.hf_config.model_type == "dbrx":
  554. return getattr(self.hf_config.attn_config, "kv_n_heads",
  555. self.hf_config.num_attention_heads)
  556. if self.is_attention_free():
  557. return 0
  558. attributes = [
  559. # For Falcon:
  560. "n_head_kv",
  561. "num_kv_heads",
  562. # For LLaMA-2:
  563. "num_key_value_heads",
  564. # For ChatGLM:
  565. "multi_query_group_num",
  566. ]
  567. for attr in attributes:
  568. num_kv_heads = getattr(self.hf_text_config, attr, None)
  569. if num_kv_heads is not None:
  570. return num_kv_heads
  571. # For non-grouped-query attention models, the number of KV heads is
  572. # equal to the number of attention heads.
  573. return self.hf_text_config.num_attention_heads
  574. def get_num_kv_heads(self,
  575. parallel_config: "ParallelConfig",
  576. tp_rank: int = 0) -> int:
  577. """Returns the number of KV heads per GPU."""
  578. total_num_kv_heads = self.get_total_num_kv_heads()
  579. # If tensor parallelism is used, we divide the number of KV heads by
  580. # the tensor parallel size. We will replicate the KV heads in the
  581. # case where the number of KV heads is smaller than the tensor
  582. # parallel size so each GPU has at least one KV head.
  583. result = get_current_tp_rank_partition_size(
  584. total_num_kv_heads, tp_rank, parallel_config.tensor_parallel_size)
  585. return max(1, result)
  586. def get_num_attention_heads(self,
  587. parallel_config: "ParallelConfig",
  588. tp_rank: int = 0) -> int:
  589. if getattr(self.hf_text_config, "num_attention_heads", None) is None:
  590. return 0
  591. num_total_kv_heads = self.get_total_num_kv_heads()
  592. num_kv_heads = self.get_num_kv_heads(parallel_config, tp_rank)
  593. num_total_attention_heads = self.hf_text_config.num_attention_heads
  594. num_heads_per_kv_head = num_total_attention_heads // num_total_kv_heads
  595. # For GQA attention we make sure the whole attention head group is
  596. # together on the same GPU.
  597. return num_kv_heads * num_heads_per_kv_head
  598. def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  599. from aphrodite.distributed.utils import get_pp_indices
  600. total_num_hidden_layers = getattr(self.hf_text_config,
  601. "num_hidden_layers", 0)
  602. pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
  603. pp_size = parallel_config.pipeline_parallel_size
  604. start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
  605. return end - start
  606. def contains_seqlen_agnostic_layers(
  607. self, parallel_config: "ParallelConfig") -> bool:
  608. """True for Mamba/SSM models (Jamba)"""
  609. return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
  610. def get_layers_block_type(self,
  611. parallel_config: "ParallelConfig") -> List[str]:
  612. num_layers = self.get_num_layers(parallel_config)
  613. if self.is_attention_free():
  614. assert (self.hf_config.model_type == "mamba")
  615. return ["mamba"] * num_layers
  616. # Transformers supports layers_block_type @property
  617. return getattr(self.hf_config, "layers_block_type",
  618. ["attention"] * num_layers)
  619. def get_num_attention_layers(self,
  620. parallel_config: "ParallelConfig") -> int:
  621. return len([
  622. t for t in self.get_layers_block_type(parallel_config)
  623. if t == "attention"
  624. ])
  625. def _get_num_seqlen_agnostic_layers(
  626. self, parallel_config: "ParallelConfig") -> int:
  627. return len([
  628. t for t in self.get_layers_block_type(parallel_config)
  629. if t != "attention"
  630. ])
  631. def get_multimodal_config(self) -> "MultiModalConfig":
  632. """
  633. Get the multimodal configuration of the model.
  634. Raises:
  635. ValueError: If the model is not multimodal.
  636. """
  637. if self.multimodal_config is None:
  638. raise ValueError("The model is not multimodal.")
  639. return self.multimodal_config
  640. @property
  641. def is_encoder_decoder_model(self) -> bool:
  642. """Extract the HF encoder/decoder model flag."""
  643. return getattr(self.hf_config, "is_encoder_decoder", False)
  644. @property
  645. def is_embedding_model(self) -> bool:
  646. """Extract the embedding model flag."""
  647. return self.embedding_mode
  648. @property
  649. def is_multimodal_model(self) -> bool:
  650. return self.multimodal_config is not None
  651. class CacheConfig:
  652. """Configuration for the KV cache.
  653. Args:
  654. block_size: Size of a cache block in number of tokens.
  655. gpu_memory_utilization: Fraction of GPU memory to use for the
  656. Aphrodite execution.
  657. swap_space: Size of the CPU swap space per GPU (in GiB).
  658. cache_dtype: Data type for kv cache storage.
  659. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
  660. profiled num_gpu_blocks if specified. Does nothing if None.
  661. """
  662. def __init__(
  663. self,
  664. block_size: int,
  665. gpu_memory_utilization: float,
  666. swap_space: float,
  667. cache_dtype: str,
  668. is_attention_free: bool = False,
  669. num_gpu_blocks_override: Optional[int] = None,
  670. sliding_window: Optional[int] = None,
  671. enable_prefix_caching: bool = False,
  672. cpu_offload_gb: float = 0.0,
  673. ) -> None:
  674. self.block_size = block_size
  675. self.gpu_memory_utilization = gpu_memory_utilization
  676. self.swap_space_bytes = swap_space * GiB_bytes
  677. self.num_gpu_blocks_override = num_gpu_blocks_override
  678. self.cache_dtype = cache_dtype
  679. self.is_attention_free = is_attention_free
  680. self.sliding_window = sliding_window
  681. self.enable_prefix_caching = enable_prefix_caching
  682. self.cpu_offload_gb = cpu_offload_gb
  683. self._verify_args()
  684. self._verify_cache_dtype()
  685. self._verify_prefix_caching()
  686. # Will be set after profiling.
  687. self.num_gpu_blocks = None
  688. self.num_cpu_blocks = None
  689. def metrics_info(self):
  690. # convert cache_config to dict(key: str, value: str) for prometheus
  691. # metrics info
  692. return {key: str(value) for key, value in self.__dict__.items()}
  693. def _verify_args(self) -> None:
  694. if self.gpu_memory_utilization > 1.0:
  695. raise ValueError(
  696. "GPU memory utilization must be less than 1.0. Got "
  697. f"{self.gpu_memory_utilization}.")
  698. def _verify_cache_dtype(self) -> None:
  699. if self.cache_dtype == "auto":
  700. pass
  701. elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
  702. logger.info(
  703. "Using fp8 data type to store kv cache. It reduces the GPU "
  704. "memory footprint and boosts the performance. "
  705. "Meanwhile, it may cause accuracy drop without a proper "
  706. "scaling factor")
  707. else:
  708. raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
  709. def _verify_prefix_caching(self) -> None:
  710. if not self.enable_prefix_caching:
  711. return
  712. if self.sliding_window is not None:
  713. raise NotImplementedError(
  714. "Prefix caching is not supported with sliding window. "
  715. "Run with --disable-sliding-window to use prefix caching.")
  716. if self.cache_dtype == "fp8":
  717. capability = current_platform.get_device_capability()
  718. capability = capability[0] * 10 + capability[1]
  719. if capability < 89:
  720. raise NotImplementedError(
  721. "FP8 KV cache with prefix caching is only supported on "
  722. "GPUs with compute capability 8.9 or higher (e.g., "
  723. "4090, H100). Your GPU has compute capability "
  724. f"{capability}")
  725. if not HAS_TRITON and self.enable_prefix_caching:
  726. raise ValueError("Triton is not installed, "
  727. "prefix caching will not work.")
  728. def verify_with_parallel_config(
  729. self,
  730. parallel_config: "ParallelConfig",
  731. ) -> None:
  732. total_cpu_memory = get_cpu_memory()
  733. # FIXME: Here, it is assumed that the GPUs in a tensor parallel
  734. # group are in the same node. However, the GPUs may span multiple nodes.
  735. num_gpus_per_node = parallel_config.tensor_parallel_size
  736. cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
  737. msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
  738. f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
  739. "is allocated for the swap space.")
  740. if cpu_memory_usage > 0.7 * total_cpu_memory:
  741. raise ValueError("Too large swap space. " + msg)
  742. elif cpu_memory_usage > 0.4 * total_cpu_memory:
  743. logger.warning("Possibly too large swap space. " + msg)
  744. @dataclass
  745. class TokenizerPoolConfig:
  746. """Configuration for the tokenizer pool.
  747. Args:
  748. pool_size: Number of tokenizer workers in the pool.
  749. pool_type: Type of the pool.
  750. extra_config: Additional config for the pool.
  751. The way the config will be used depends on the
  752. pool type.
  753. """
  754. pool_size: int
  755. pool_type: Union[str, Type["BaseTokenizerGroup"]]
  756. extra_config: dict
  757. def __post_init__(self):
  758. if self.pool_type not in ("ray", ) and not isinstance(
  759. self.pool_type, type):
  760. raise ValueError(f"Unknown pool type: {self.pool_type}")
  761. if not isinstance(self.extra_config, dict):
  762. raise ValueError("extra_config must be a dictionary.")
  763. @classmethod
  764. def create_config(
  765. cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
  766. tokenizer_pool_extra_config: Optional[Union[str, dict]]
  767. ) -> Optional["TokenizerPoolConfig"]:
  768. """Create a TokenizerPoolConfig from the given parameters.
  769. If tokenizer_pool_size is 0, return None.
  770. Args:
  771. tokenizer_pool_size: Number of tokenizer workers in the pool.
  772. tokenizer_pool_type: Type of the pool.
  773. tokenizer_pool_extra_config: Additional config for the pool.
  774. The way the config will be used depends on the
  775. pool type. This can be a JSON string (will be parsed).
  776. """
  777. if tokenizer_pool_size:
  778. if isinstance(tokenizer_pool_extra_config, str):
  779. tokenizer_pool_extra_config_parsed = json.loads(
  780. tokenizer_pool_extra_config)
  781. else:
  782. tokenizer_pool_extra_config_parsed = (
  783. tokenizer_pool_extra_config or {})
  784. tokenizer_pool_config = cls(tokenizer_pool_size,
  785. tokenizer_pool_type,
  786. tokenizer_pool_extra_config_parsed)
  787. else:
  788. tokenizer_pool_config = None
  789. return tokenizer_pool_config
  790. class LoadFormat(str, enum.Enum):
  791. AUTO = "auto"
  792. PT = "pt"
  793. SAFETENSORS = "safetensors"
  794. NPCACHE = "npcache"
  795. DUMMY = "dummy"
  796. TENSORIZER = "tensorizer"
  797. SHARDED_STATE = "sharded_state"
  798. GGUF = "gguf"
  799. BITSANDBYTES = "bitsandbytes"
  800. MISTRAL = "mistral"
  801. @dataclass
  802. class LoadConfig:
  803. """
  804. download_dir: Directory to download and load the weights, default to the
  805. default cache directory of huggingface.
  806. load_format: The format of the model weights to load:
  807. "auto" will try to load the weights in the safetensors format and
  808. fall back to the pytorch bin format if safetensors format is
  809. not available.
  810. "pt" will load the weights in the pytorch bin format.
  811. "safetensors" will load the weights in the safetensors format.
  812. "npcache" will load the weights in pytorch format and store
  813. a numpy cache to speed up the loading.
  814. "dummy" will initialize the weights with random values, which is
  815. mainly for profiling.
  816. "tensorizer" will use CoreWeave's tensorizer library for
  817. fast weight loading.
  818. ignore_patterns: The list of patterns to ignore when loading the model.
  819. Default to "original/**/*" to avoid repeated loading of llama's
  820. checkpoints.
  821. """
  822. load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
  823. download_dir: Optional[str] = None
  824. model_loader_extra_config: Optional[Union[str, dict]] = field(
  825. default_factory=dict)
  826. ignore_patterns: Optional[Union[List[str], str]] = None
  827. def __post_init__(self):
  828. model_loader_extra_config = self.model_loader_extra_config or {}
  829. if isinstance(model_loader_extra_config, str):
  830. self.model_loader_extra_config = json.loads(
  831. model_loader_extra_config)
  832. self._verify_load_format()
  833. if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
  834. logger.info(
  835. "Ignoring the following patterns when downloading weights: "
  836. f"{self.ignore_patterns}")
  837. else:
  838. self.ignore_patterns = ["original/**/*"]
  839. def _verify_load_format(self) -> None:
  840. if not isinstance(self.load_format, str):
  841. return
  842. load_format = self.load_format.lower()
  843. self.load_format = LoadFormat(load_format)
  844. rocm_not_supported_load_format: List[str] = []
  845. if is_hip() and load_format in rocm_not_supported_load_format:
  846. rocm_supported_load_format = [
  847. f for f in LoadFormat.__members__
  848. if (f not in rocm_not_supported_load_format)
  849. ]
  850. raise ValueError(
  851. f"load format '{load_format}' is not supported in ROCm. "
  852. f"Supported load formats are "
  853. f"{rocm_supported_load_format}")
  854. class ParallelConfig:
  855. """Configuration for the distributed execution.
  856. Args:
  857. pipeline_parallel_size: Number of pipeline parallel groups.
  858. tensor_parallel_size: Number of tensor parallel groups.
  859. worker_use_ray: Deprecated, use distributed_executor_backend instead.
  860. max_parallel_loading_workers: Maximum number of multiple batches
  861. when load model sequentially. To avoid RAM OOM when using tensor
  862. parallel and large models.
  863. disable_custom_all_reduce: Disable the custom all-reduce kernel and
  864. fall back to NCCL.
  865. tokenizer_pool_config: Config for the tokenizer pool.
  866. If None, will use synchronous tokenization.
  867. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
  868. https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
  869. placement_group: ray distributed model workers placement group.
  870. distributed_executor_backend: Backend to use for distributed model
  871. workers, either "ray" or "mp" (multiprocessing). If either
  872. pipeline_parallel_size or tensor_parallel_size is greater than 1,
  873. will default to "ray" if Ray is installed or "mp" otherwise.
  874. """
  875. def __init__(
  876. self,
  877. pipeline_parallel_size: int,
  878. tensor_parallel_size: int,
  879. worker_use_ray: Optional[bool] = None,
  880. max_parallel_loading_workers: Optional[int] = None,
  881. disable_custom_all_reduce: bool = False,
  882. tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
  883. ray_workers_use_nsight: bool = False,
  884. placement_group: Optional["PlacementGroup"] = None,
  885. distributed_executor_backend: Optional[Union[
  886. str, Type["ExecutorBase"]]] = None,
  887. ) -> None:
  888. self.pipeline_parallel_size = pipeline_parallel_size
  889. self.tensor_parallel_size = tensor_parallel_size
  890. self.distributed_executor_backend = distributed_executor_backend
  891. self.max_parallel_loading_workers = max_parallel_loading_workers
  892. self.disable_custom_all_reduce = disable_custom_all_reduce
  893. self.tokenizer_pool_config = tokenizer_pool_config
  894. self.ray_workers_use_nsight = ray_workers_use_nsight
  895. self.placement_group = placement_group
  896. self.world_size = pipeline_parallel_size * self.tensor_parallel_size
  897. if worker_use_ray:
  898. if self.distributed_executor_backend is None:
  899. self.distributed_executor_backend = "ray"
  900. elif not self.use_ray:
  901. raise ValueError(f"worker-use-ray can't be used with "
  902. f"distributed executor backend "
  903. f"'{self.distributed_executor_backend}'.")
  904. if current_platform.is_tpu() and self.world_size > 1:
  905. if self.distributed_executor_backend is None:
  906. self.distributed_executor_backend = "ray"
  907. if self.distributed_executor_backend != "ray":
  908. raise ValueError(
  909. "TPU backend only supports Ray for distributed inference.")
  910. if self.distributed_executor_backend is None and self.world_size > 1:
  911. # We use multiprocessing by default if world_size fits on the
  912. # current node and we aren't in a ray placement group.
  913. from aphrodite.executor import ray_utils
  914. backend = "mp"
  915. ray_found = ray_utils.ray_is_available()
  916. if not is_cpu() and cuda_device_count_stateless() < self.world_size:
  917. if not ray_found:
  918. raise ValueError("Unable to load Ray which is "
  919. "required for multi-node inference, "
  920. "please install Ray with `pip install "
  921. "ray`.") from ray_utils.ray_import_err
  922. backend = "ray"
  923. elif ray_found:
  924. if self.placement_group:
  925. backend = "ray"
  926. else:
  927. from ray import is_initialized as ray_is_initialized
  928. if ray_is_initialized():
  929. from ray.util import get_current_placement_group
  930. if get_current_placement_group():
  931. backend = "ray"
  932. self.distributed_executor_backend = backend
  933. logger.info(
  934. f"Defaulting to use {backend} for distributed inference.")
  935. self._verify_args()
  936. self.rank = 0
  937. @property
  938. def use_ray(self) -> bool:
  939. return self.distributed_executor_backend == "ray" or (
  940. isinstance(self.distributed_executor_backend, type)
  941. and self.distributed_executor_backend.uses_ray)
  942. def _verify_args(self) -> None:
  943. # Lazy import to avoid circular import
  944. from aphrodite.executor.executor_base import ExecutorBase
  945. if self.distributed_executor_backend not in (
  946. "ray", "mp", None) and not (isinstance(
  947. self.distributed_executor_backend, type) and issubclass(
  948. self.distributed_executor_backend, ExecutorBase)):
  949. raise ValueError(
  950. "Unrecognized distributed executor backend "
  951. f"{self.distributed_executor_backend}. Supported "
  952. "values are 'ray', 'mp' or custom ExecutorBase subclass.")
  953. if self.use_ray:
  954. from aphrodite.executor import ray_utils
  955. ray_utils.assert_ray_available()
  956. if is_hip():
  957. self.disable_custom_all_reduce = True
  958. logger.info(
  959. "Disabled the custom all-reduce kernel because it is not "
  960. "supported on AMD GPUs.")
  961. if self.ray_workers_use_nsight and not self.use_ray:
  962. raise ValueError("Unable to use nsight profiling unless workers "
  963. "run with Ray.")
  964. class SchedulerConfig:
  965. """Scheduler configuration.
  966. Args:
  967. max_num_batched_tokens: Maximum number of tokens to be processed in
  968. a single iteration.
  969. max_num_seqs: Maximum number of sequences to be processed in a single
  970. iteration.
  971. max_model_len: Maximum length of a sequence (including prompt
  972. and generated text).
  973. is_attention_free: True if the running model does not have state that
  974. grows as the context size increases.
  975. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
  976. num_lookahead_slots: The number of slots to allocate per sequence per
  977. step, beyond the known token ids. This is used in speculative
  978. decoding to store KV activations of tokens which may or may not be
  979. accepted.
  980. delay_factor: Apply a delay (of delay factor multiplied by previous
  981. prompt latency) before scheduling next prompt.
  982. enable_chunked_prefill: If True, prefill requests can be chunked based
  983. on the remaining max_num_batched_tokens.
  984. embedding_mode: Whether the running model is for embedding.
  985. preemption_mode: Whether to perform preemption by swapping or
  986. recomputation. If not specified, we determine the mode as follows:
  987. We use recomputation by default since it incurs lower overhead than
  988. swapping. However, when the sequence group has multiple sequences
  989. (e.g., beam search), recomputation is not currently supported. In
  990. such a case, we use swapping instead.
  991. send_delta_data: Private API. If used, scheduler sends delta data to
  992. workers instead of an entire data. It should be enabled only
  993. when SPMD worker architecture is enabled. I.e.,
  994. APHRODITE_USE_RAY_SPMD_WORKER=1
  995. single_user_mode: If True, we only allocate blocks for one sequence
  996. and use the maximum sequence length as the number of tokens.
  997. """
  998. def __init__(self,
  999. max_num_batched_tokens: Optional[int],
  1000. max_num_seqs: int,
  1001. max_model_len: int,
  1002. cache_config: Optional["CacheConfig"] = None,
  1003. is_attention_free: bool = False,
  1004. use_v2_block_manager: bool = False,
  1005. num_lookahead_slots: int = 0,
  1006. delay_factor: float = 0.0,
  1007. enable_chunked_prefill: bool = False,
  1008. embedding_mode: bool = False,
  1009. is_multimodal_model: bool = False,
  1010. preemption_mode: Optional[str] = None,
  1011. num_scheduler_steps: int = 1,
  1012. send_delta_data: bool = False,
  1013. single_user_mode: bool = False) -> None:
  1014. if max_num_batched_tokens is None:
  1015. if enable_chunked_prefill:
  1016. # It is the values that have the best balance between ITL
  1017. # and TTFT on A100. Note it is not optimized for throughput.
  1018. max_num_batched_tokens = 512
  1019. else:
  1020. # If max_model_len is too short, use 2048 as the default value
  1021. # for higher throughput.
  1022. max_num_batched_tokens = max(max_model_len, 2048)
  1023. if embedding_mode:
  1024. # For embedding, choose specific value for higher throughput
  1025. max_num_batched_tokens = max(
  1026. max_num_batched_tokens,
  1027. _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
  1028. )
  1029. if is_multimodal_model:
  1030. # The value needs to be at least the number of multimodal tokens
  1031. max_num_batched_tokens = max(
  1032. max_num_batched_tokens,
  1033. _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
  1034. )
  1035. self.max_num_batched_tokens = max_num_batched_tokens
  1036. if enable_chunked_prefill:
  1037. logger.info(
  1038. "Chunked prefill is enabled with "
  1039. f"max_num_batched_tokens={self.max_num_batched_tokens}.")
  1040. if single_user_mode:
  1041. max_num_seqs = 1
  1042. if cache_config.enable_prefix_caching:
  1043. if not envs.APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE:
  1044. logger.warning(
  1045. "Chunked prefill is not supported in single user mode, "
  1046. "this is not recommended and may lead to memory "
  1047. "issues. Set APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE=1"
  1048. " to force prefix caching.")
  1049. cache_config.enable_prefix_caching = False
  1050. else:
  1051. logger.warning(
  1052. "Chunked prefill is enabled in single user mode, "
  1053. "this is not recommended and may lead to memory "
  1054. "issues.")
  1055. self.max_num_seqs = max_num_seqs
  1056. self.max_model_len = max_model_len
  1057. self.cache_config = cache_config
  1058. self.is_attention_free = is_attention_free
  1059. self.use_v2_block_manager = use_v2_block_manager
  1060. self.num_lookahead_slots = num_lookahead_slots
  1061. self.delay_factor = delay_factor
  1062. self.chunked_prefill_enabled = enable_chunked_prefill
  1063. self.embedding_mode = embedding_mode
  1064. self.preemption_mode = preemption_mode
  1065. self.num_scheduler_steps = num_scheduler_steps
  1066. self.send_delta_data = send_delta_data
  1067. self.single_user_mode = single_user_mode
  1068. self._verify_args()
  1069. def _verify_args(self) -> None:
  1070. if (self.max_num_batched_tokens < self.max_model_len
  1071. and not self.chunked_prefill_enabled):
  1072. raise ValueError(
  1073. f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
  1074. f"smaller than max_model_len ({self.max_model_len}). "
  1075. "This effectively limits the maximum sequence length to "
  1076. "max_num_batched_tokens and makes Aphrodite reject longer "
  1077. "sequences. Please increase max_num_batched_tokens or "
  1078. "decrease max_model_len.")
  1079. if self.max_num_batched_tokens < self.max_num_seqs:
  1080. raise ValueError(
  1081. f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
  1082. "be greater than or equal to max_num_seqs "
  1083. f"({self.max_num_seqs}).")
  1084. if self.num_lookahead_slots < 0:
  1085. raise ValueError(
  1086. "num_lookahead_slots "
  1087. f"({self.num_lookahead_slots}) must be greater than or "
  1088. "equal to 0.")
  1089. if self.num_scheduler_steps < 1:
  1090. raise ValueError(
  1091. "num_scheduler_steps "
  1092. f"({self.num_scheduler_steps}) must be greater than or "
  1093. "equal to 1.")
  1094. @property
  1095. def is_multi_step(self) -> bool:
  1096. return self.num_scheduler_steps > 1
  1097. class DeviceConfig:
  1098. def __init__(self, device: str = "auto") -> None:
  1099. if device == "auto":
  1100. # Automated device type detection
  1101. if is_neuron():
  1102. self.device_type = "neuron"
  1103. elif is_openvino():
  1104. self.device_type = "openvino"
  1105. elif current_platform.is_tpu():
  1106. self.device_type = "tpu"
  1107. elif is_cpu():
  1108. self.device_type = "cpu"
  1109. elif is_xpu():
  1110. self.device_type = "xpu"
  1111. else:
  1112. # We don't call torch.cuda.is_available() here to
  1113. # avoid initializing CUDA before workers are forked
  1114. self.device_type = "cuda"
  1115. else:
  1116. # Device type is assigned explicitly
  1117. self.device_type = device
  1118. # Some device types require processing inputs on CPU
  1119. if self.device_type in ["neuron", "openvino"]:
  1120. self.device = torch.device("cpu")
  1121. elif self.device_type in ["tpu"]:
  1122. self.device = None
  1123. else:
  1124. # Set device with device type
  1125. self.device = torch.device(self.device_type)
  1126. class SpeculativeConfig:
  1127. """Configuration for speculative decoding.
  1128. The configuration is currently specialized to draft-model speculative
  1129. decoding with top-1 proposals.
  1130. """
  1131. @staticmethod
  1132. def maybe_create_spec_config(
  1133. target_model_config: ModelConfig,
  1134. target_parallel_config: ParallelConfig,
  1135. target_dtype: str,
  1136. speculative_model: Optional[str],
  1137. speculative_model_quantization: Optional[str],
  1138. speculative_draft_tensor_parallel_size: Optional[int],
  1139. num_speculative_tokens: Optional[int],
  1140. speculative_max_model_len: Optional[int],
  1141. enable_chunked_prefill: bool,
  1142. use_v2_block_manager: bool,
  1143. disable_log_stats: bool,
  1144. speculative_disable_by_batch_size: Optional[int],
  1145. ngram_prompt_lookup_max: Optional[int],
  1146. ngram_prompt_lookup_min: Optional[int],
  1147. draft_token_acceptance_method: str,
  1148. typical_acceptance_sampler_posterior_threshold: Optional[float],
  1149. typical_acceptance_sampler_posterior_alpha: Optional[float],
  1150. disable_logprobs: Optional[bool],
  1151. ) -> Optional["SpeculativeConfig"]:
  1152. """Create a SpeculativeConfig if possible, else return None.
  1153. This function attempts to create a SpeculativeConfig object based on the
  1154. provided parameters. If the necessary conditions are met, it returns an
  1155. instance of SpeculativeConfig. Otherwise, it returns None.
  1156. Args:
  1157. target_model_config (ModelConfig): The configuration of the target
  1158. model.
  1159. target_parallel_config (ParallelConfig): The parallel configuration
  1160. for the target model.
  1161. target_dtype (str): The data type used for the target model.
  1162. speculative_model (Optional[str]): The name of the speculative
  1163. model, if provided.
  1164. num_speculative_tokens (Optional[int]): The number of speculative
  1165. tokens, if provided. Will default to the number in the draft
  1166. model config if present, otherwise is required.
  1167. speculative_model_quantization (Optional[str]): Quantization method
  1168. that was used to quantize the speculative model weights. If
  1169. None, we assume the model weights are not quantized.
  1170. speculative_draft_tensor_parallel_size (Optional[int]): The degree
  1171. of the tensor parallelism for the draft model.
  1172. speculative_max_model_len (Optional[int]): The maximum model len of
  1173. the speculative model. Used when testing the ability to skip
  1174. speculation for some sequences.
  1175. enable_chunked_prefill (bool): Whether Aphrodite is configured to
  1176. use chunked prefill or not. Used for raising an error since its
  1177. not yet compatible with spec decode.
  1178. use_v2_block_manager (bool): Whether Aphrodite is configured to
  1179. use the v2 block manager or not. Used for raising an error
  1180. since the v2 block manager is required with spec decode.
  1181. speculative_disable_by_batch_size (Optional[int]): Disable
  1182. speculative decoding for new incoming requests when the number
  1183. of enqueue requests is larger than this value, if provided.
  1184. ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
  1185. window, if provided.
  1186. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
  1187. window, if provided.
  1188. draft_token_acceptance_method (str): The method to use for
  1189. accepting draft tokens. This can take two possible
  1190. values 'rejection_sampler' and 'typical_acceptance_sampler'
  1191. for RejectionSampler and TypicalAcceptanceSampler
  1192. respectively.
  1193. typical_acceptance_sampler_posterior_threshold (Optional[float]):
  1194. A threshold value that sets a lower bound on the posterior
  1195. probability of a token in the target model for it to be
  1196. accepted. This threshold is used only when we use the
  1197. TypicalAcceptanceSampler for token acceptance.
  1198. typical_acceptance_sampler_posterior_alpha (Optional[float]):
  1199. A scaling factor for the entropy-based threshold in the
  1200. TypicalAcceptanceSampler.
  1201. disable_logprobs (Optional[bool]): If set to True, token log
  1202. probabilities are not returned during speculative decoding.
  1203. If set to False, token log probabilities are returned
  1204. according to the log probability settings in SamplingParams.
  1205. If not specified, it defaults to True.
  1206. Returns:
  1207. Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
  1208. the necessary conditions are met, else None.
  1209. """
  1210. if speculative_model is None:
  1211. if num_speculative_tokens is not None:
  1212. raise ValueError("num_speculative_tokens was provided without "
  1213. "speculative_model.")
  1214. return None
  1215. if (speculative_disable_by_batch_size is not None
  1216. and speculative_disable_by_batch_size < 2):
  1217. raise ValueError("Expected the batch size threshold of disabling "
  1218. "speculative decoding is > 1, but got "
  1219. f"{speculative_disable_by_batch_size=}")
  1220. if enable_chunked_prefill:
  1221. raise ValueError(
  1222. "Speculative decoding and chunked prefill are "
  1223. f"currently mutually exclusive ({enable_chunked_prefill=}).")
  1224. if not use_v2_block_manager:
  1225. raise ValueError(
  1226. "Speculative decoding requires usage of the V2 "
  1227. "block manager. Enable it with --use-v2-block-manager.")
  1228. # TODO: The user should be able to specify revision/max model len
  1229. # for the draft model. It is not currently supported.
  1230. draft_revision = None
  1231. draft_code_revision = None
  1232. draft_quantization = speculative_model_quantization
  1233. if speculative_model == "[ngram]":
  1234. if ngram_prompt_lookup_min is None:
  1235. ngram_prompt_lookup_min = 1
  1236. if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
  1237. raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
  1238. if ngram_prompt_lookup_min < 1:
  1239. raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
  1240. if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
  1241. raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
  1242. f"larger than {ngram_prompt_lookup_max=}")
  1243. # TODO: current we still need extract vocab_size from target model
  1244. # config, in future, we may try refactoring it out, and set
  1245. # draft related config as None here.
  1246. draft_model_config = target_model_config
  1247. draft_parallel_config = target_parallel_config
  1248. else:
  1249. ngram_prompt_lookup_max = 0
  1250. ngram_prompt_lookup_min = 0
  1251. draft_model_config = ModelConfig(
  1252. model=speculative_model,
  1253. tokenizer=target_model_config.tokenizer,
  1254. tokenizer_mode=target_model_config.tokenizer_mode,
  1255. trust_remote_code=target_model_config.trust_remote_code,
  1256. dtype=target_model_config.dtype,
  1257. seed=target_model_config.seed,
  1258. revision=draft_revision,
  1259. code_revision=draft_code_revision,
  1260. tokenizer_revision=target_model_config.tokenizer_revision,
  1261. max_model_len=None,
  1262. spec_target_max_model_len=target_model_config.max_model_len,
  1263. quantization=draft_quantization,
  1264. enforce_eager=target_model_config.enforce_eager,
  1265. max_seq_len_to_capture=target_model_config.
  1266. max_seq_len_to_capture,
  1267. max_logprobs=target_model_config.max_logprobs,
  1268. )
  1269. draft_hf_config = draft_model_config.hf_config
  1270. if (num_speculative_tokens is not None
  1271. and hasattr(draft_hf_config, "num_lookahead_tokens")):
  1272. draft_hf_config.num_lookahead_tokens = num_speculative_tokens
  1273. n_predict = getattr(draft_hf_config, "n_predict", None)
  1274. if n_predict is not None:
  1275. if num_speculative_tokens is None:
  1276. # Default to max value defined in draft model config.
  1277. num_speculative_tokens = n_predict
  1278. elif num_speculative_tokens > n_predict:
  1279. # Verify provided value doesn't exceed the maximum
  1280. # supported by the draft model.
  1281. raise ValueError(
  1282. "This speculative model supports a maximum of "
  1283. f"num_speculative_tokens={n_predict}, but "
  1284. f"{num_speculative_tokens=} was provided.")
  1285. draft_model_config.max_model_len = (
  1286. SpeculativeConfig._maybe_override_draft_max_model_len(
  1287. speculative_max_model_len,
  1288. draft_model_config.max_model_len,
  1289. target_model_config.max_model_len,
  1290. ))
  1291. draft_parallel_config = (
  1292. SpeculativeConfig.create_draft_parallel_config(
  1293. target_parallel_config,
  1294. speculative_draft_tensor_parallel_size))
  1295. if num_speculative_tokens is None:
  1296. raise ValueError(
  1297. "num_speculative_tokens must be provided with "
  1298. "speculative_model unless the draft model config contains an "
  1299. "n_predict parameter.")
  1300. if typical_acceptance_sampler_posterior_threshold is None:
  1301. typical_acceptance_sampler_posterior_threshold = 0.09
  1302. if typical_acceptance_sampler_posterior_alpha is None:
  1303. typical_acceptance_sampler_posterior_alpha = 0.3
  1304. if disable_logprobs is None:
  1305. disable_logprobs = True
  1306. return SpeculativeConfig(
  1307. draft_model_config,
  1308. draft_parallel_config,
  1309. num_speculative_tokens,
  1310. speculative_disable_by_batch_size,
  1311. ngram_prompt_lookup_max,
  1312. ngram_prompt_lookup_min,
  1313. draft_token_acceptance_method=draft_token_acceptance_method,
  1314. typical_acceptance_sampler_posterior_threshold=\
  1315. typical_acceptance_sampler_posterior_threshold,
  1316. typical_acceptance_sampler_posterior_alpha=\
  1317. typical_acceptance_sampler_posterior_alpha,
  1318. disable_logprobs=disable_logprobs,
  1319. disable_log_stats=disable_log_stats,
  1320. )
  1321. @staticmethod
  1322. def _maybe_override_draft_max_model_len(
  1323. speculative_max_model_len: Optional[int],
  1324. draft_max_model_len: int,
  1325. target_max_model_len: int,
  1326. ) -> int:
  1327. """Determine the max sequence len for the draft model. This is usually
  1328. the draft_max_model_len, but may be the target_max_model_len if it is
  1329. less than the draft_max_model_len, or may be speculative_max_model_len
  1330. if it is specified.
  1331. This is necessary so that sequences do not exceed the capacity of the
  1332. draft model or the target model.
  1333. speculative_max_model_len is mainly used for testing that sequences can
  1334. skip speculation.
  1335. """
  1336. if speculative_max_model_len is not None:
  1337. if speculative_max_model_len > draft_max_model_len:
  1338. raise ValueError(f"{speculative_max_model_len=} cannot be "
  1339. f"larger than {draft_max_model_len=}")
  1340. if speculative_max_model_len > target_max_model_len:
  1341. raise ValueError(f"{speculative_max_model_len=} cannot be "
  1342. f"larger than {target_max_model_len=}")
  1343. return speculative_max_model_len
  1344. return min(
  1345. draft_max_model_len,
  1346. target_max_model_len,
  1347. )
  @staticmethod
  def create_draft_parallel_config(
      target_parallel_config: ParallelConfig,
      speculative_draft_tensor_parallel_size: Optional[int]
  ) -> ParallelConfig:
      """Build the ParallelConfig used by the draft worker.

      Every setting is copied from the target parallel config except the
      tensor parallel size, which may be overridden.

      Args:
          target_parallel_config: Parallel config of the target model.
          speculative_draft_tensor_parallel_size: Tensor parallel size for
              the draft model. ``None`` reuses the target's size; the only
              explicit value currently accepted is 1.

      Returns:
          A new ParallelConfig for the draft model.

      Raises:
          ValueError: If an explicit tensor parallel size other than 1 is
              requested.
      """
      target = target_parallel_config

      # Validate first, then resolve the effective tensor parallel size.
      if (speculative_draft_tensor_parallel_size is not None
              and speculative_draft_tensor_parallel_size != 1):
          # TODO: allow tp values larger than 1
          raise ValueError(
              f"{speculative_draft_tensor_parallel_size=} cannot be "
              "other value than 1")
      if speculative_draft_tensor_parallel_size is None:
          speculative_draft_tensor_parallel_size = target.tensor_parallel_size

      return ParallelConfig(
          pipeline_parallel_size=target.pipeline_parallel_size,
          tensor_parallel_size=speculative_draft_tensor_parallel_size,
          distributed_executor_backend=target.distributed_executor_backend,
          max_parallel_loading_workers=target.max_parallel_loading_workers,
          disable_custom_all_reduce=target.disable_custom_all_reduce,
          tokenizer_pool_config=target.tokenizer_pool_config,
          ray_workers_use_nsight=target.ray_workers_use_nsight,
          placement_group=target.placement_group,
      )
  1380. def __init__(
  1381. self,
  1382. draft_model_config: ModelConfig,
  1383. draft_parallel_config: ParallelConfig,
  1384. num_speculative_tokens: int,
  1385. speculative_disable_by_batch_size: Optional[int],
  1386. ngram_prompt_lookup_max: Optional[int],
  1387. ngram_prompt_lookup_min: Optional[int],
  1388. draft_token_acceptance_method: str,
  1389. typical_acceptance_sampler_posterior_threshold: float,
  1390. typical_acceptance_sampler_posterior_alpha: float,
  1391. disable_logprobs: bool,
  1392. disable_log_stats: bool,
  1393. ):
  1394. """Create a SpeculativeConfig object.
  1395. Args:
  1396. draft_model_config: ModelConfig for the draft model.
  1397. draft_parallel_config: ParallelConfig for the draft model.
  1398. num_speculative_tokens: The number of tokens to sample from the
  1399. draft model before scoring with the target model.
  1400. speculative_disable_by_batch_size: Disable speculative
  1401. decoding for new incoming requests when the number of
  1402. enqueue requests is larger than this value.
  1403. ngram_prompt_lookup_max: Max size of ngram token window.
  1404. ngram_prompt_lookup_min: Min size of ngram token window.
  1405. draft_token_acceptance_method (str): The method to use for
  1406. accepting draft tokens. This can take two possible
  1407. values 'rejection_sampler' and 'typical_acceptance_sampler'
  1408. for RejectionSampler and TypicalAcceptanceSampler
  1409. respectively.
  1410. typical_acceptance_sampler_posterior_threshold (Optional[float]):
  1411. A threshold value that sets a lower bound on the posterior
  1412. probability of a token in the target model for it to be
  1413. accepted. This threshold is used only when we use the
  1414. TypicalAcceptanceSampler for token acceptance.
  1415. typical_acceptance_sampler_posterior_alpha (Optional[float]):
  1416. A scaling factor for the entropy-based threshold in the
  1417. TypicalAcceptanceSampler.
  1418. disable_logprobs: If set to True, token log probabilities will not
  1419. be returned even if requested by sampling parameters. This
  1420. reduces latency by skipping logprob calculation in proposal
  1421. sampling, target sampling, and after accepted tokens are
  1422. determined. If set to False, log probabilities will be
  1423. returned.
  1424. disable_log_stats: Whether to disable periodic printing of stage
  1425. times in speculative decoding.
  1426. """
  1427. self.draft_model_config = draft_model_config
  1428. self.draft_parallel_config = draft_parallel_config
  1429. self.num_speculative_tokens = num_speculative_tokens
  1430. self.speculative_disable_by_batch_size = \
  1431. speculative_disable_by_batch_size
  1432. self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
  1433. self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
  1434. self.draft_token_acceptance_method = draft_token_acceptance_method
  1435. self.typical_acceptance_sampler_posterior_threshold = \
  1436. typical_acceptance_sampler_posterior_threshold
  1437. self.typical_acceptance_sampler_posterior_alpha = \
  1438. typical_acceptance_sampler_posterior_alpha
  1439. self.disable_logprobs = disable_logprobs
  1440. self.disable_log_stats = disable_log_stats
  1441. self._verify_args()
  1442. def _verify_args(self) -> None:
  1443. if self.num_speculative_tokens <= 0:
  1444. raise ValueError("Expected num_speculative_tokens to be greater "
  1445. f"than zero ({self.num_speculative_tokens}).")
  1446. if self.draft_model_config:
  1447. self.draft_model_config.verify_with_parallel_config(
  1448. self.draft_parallel_config)
  1449. # Validate and set draft token acceptance related settings.
  1450. if (self.draft_token_acceptance_method is None):
  1451. raise ValueError("draft_token_acceptance_method is not set. "
  1452. "Expected values are rejection_sampler or "
  1453. "typical_acceptance_sampler.")
  1454. if (self.draft_token_acceptance_method != 'rejection_sampler'
  1455. and self.draft_token_acceptance_method !=
  1456. 'typical_acceptance_sampler'):
  1457. raise ValueError(
  1458. "Expected draft_token_acceptance_method to be either "
  1459. "rejection_sampler or typical_acceptance_sampler. Instead it "
  1460. f"is {self.draft_token_acceptance_method}")
  1461. if (self.typical_acceptance_sampler_posterior_threshold < 0
  1462. or self.typical_acceptance_sampler_posterior_alpha < 0):
  1463. raise ValueError(
  1464. "Expected typical_acceptance_sampler_posterior_threshold "
  1465. "and typical_acceptance_sampler_posterior_alpha to be > 0. "
  1466. "Instead found "
  1467. f"typical_acceptance_sampler_posterior_threshold = "
  1468. f"{self.typical_acceptance_sampler_posterior_threshold} and "
  1469. f"typical_acceptance_sampler_posterior_alpha = "
  1470. f"{self.typical_acceptance_sampler_posterior_alpha}")
  1471. @property
  1472. def num_lookahead_slots(self) -> int:
  1473. """The number of additional slots the scheduler should allocate per
  1474. step, in addition to the slots allocated for each known token.
  1475. This is equal to the number of speculative tokens, as each speculative
  1476. token must be scored.
  1477. """
  1478. return self.num_speculative_tokens
  1479. def __repr__(self) -> str:
  1480. if self.ngram_prompt_lookup_max > 0:
  1481. draft_model = "[ngram]"
  1482. else:
  1483. draft_model = self.draft_model_config.model
  1484. num_spec_tokens = self.num_speculative_tokens
  1485. return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
  @dataclass
  class LoRAConfig:
      """Configuration for LoRA adapter support."""

      max_lora_rank: int
      max_loras: int
      fully_sharded_loras: bool = False
      max_cpu_loras: Optional[int] = None
      lora_dtype: Optional[torch.dtype] = None
      lora_extra_vocab_size: int = 256
      # This is a constant.
      lora_vocab_padding_size: ClassVar[int] = 256
      long_lora_scaling_factors: Optional[Tuple[float]] = None

      def __post_init__(self):
          """Validate the field values and default ``max_cpu_loras``.

          Raises:
              ValueError: If the rank or extra vocab size is not one of the
                  supported values, if ``max_loras`` is below 1, or if
                  ``max_cpu_loras`` is smaller than ``max_loras``.
          """
          # Setting the maximum rank to 256 should be able to satisfy the
          # vast majority of applications.
          allowed_ranks = (8, 16, 32, 64, 128, 256)
          allowed_extra_vocab_sizes = (0, 256, 512)

          if self.max_lora_rank not in allowed_ranks:
              raise ValueError(
                  f"max_lora_rank ({self.max_lora_rank}) must be one of "
                  f"{allowed_ranks}.")
          if self.lora_extra_vocab_size not in allowed_extra_vocab_sizes:
              raise ValueError(
                  f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                  f"must be one of {allowed_extra_vocab_sizes}.")
          if self.max_loras < 1:
              raise ValueError(
                  f"max_loras ({self.max_loras}) must be >= 1.")

          # When unset, the CPU-side capacity defaults to max_loras.
          if self.max_cpu_loras is None:
              self.max_cpu_loras = self.max_loras
              return
          if self.max_cpu_loras < self.max_loras:
              raise ValueError(
                  f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                  f"max_loras ({self.max_loras})")

      def verify_with_model_config(self, model_config: ModelConfig):
          """Resolve ``lora_dtype`` and warn on untested quantization.

          ``None`` or "auto" takes the model's dtype; any other string is
          looked up as an attribute of ``torch``.
          """
          if self.lora_dtype in (None, "auto"):
              self.lora_dtype = model_config.dtype
          elif isinstance(self.lora_dtype, str):
              self.lora_dtype = getattr(torch, self.lora_dtype)

          quantization = model_config.quantization
          if quantization and quantization not in ["awq", "gptq"]:
              # TODO support all other quants
              logger.warning(f"{quantization} quantization is not "
                             "tested with LoRA yet.")

      def verify_with_scheduler_config(self,
                                       scheduler_config: SchedulerConfig):
          """Warn when chunked prefill is enabled together with LoRA."""
          if not scheduler_config.chunked_prefill_enabled:
              return
          logger.warning(
              "Chunked Prefill with LoRA is not rigorously tested.")

      def verify_with_parallel_config(self, parallel_config: ParallelConfig):
          """Require the vocab padding size to divide evenly by world size.

          Raises:
              ValueError: If ``lora_vocab_padding_size`` is not a multiple
                  of ``parallel_config.world_size``.
          """
          remainder = (self.lora_vocab_padding_size %
                       parallel_config.world_size)
          if remainder != 0:
              raise ValueError("LoRA vocab padding size must be divisible "
                               "by world size.")
  @dataclass
  class PromptAdapterConfig:
      """Configuration for prompt adapter support."""

      max_prompt_adapters: int
      max_prompt_adapter_token: int
      max_cpu_prompt_adapters: Optional[int] = None
      prompt_adapter_dtype: Optional[torch.dtype] = None

      def __post_init__(self):
          """Validate the limits and default ``max_cpu_prompt_adapters``.

          Raises:
              ValueError: If ``max_prompt_adapters`` is below 1 or
                  ``max_prompt_adapter_token`` is 0.
          """
          if self.max_prompt_adapters < 1:
              raise ValueError("max_prompt_adapters "
                               f"({self.max_prompt_adapters}) must be >= 1.")
          if self.max_prompt_adapter_token == 0:
              raise ValueError("max_prompt_adapter_token must be set.")
          # When unset, the CPU-side capacity defaults to max_prompt_adapters.
          if self.max_cpu_prompt_adapters is None:
              self.max_cpu_prompt_adapters = self.max_prompt_adapters

      def verify_with_model_config(self, model_config: ModelConfig):
          """Resolve ``prompt_adapter_dtype`` against the model config.

          ``None`` or "auto" takes the model's dtype; any other string is
          looked up as an attribute of ``torch``.
          """
          requested = self.prompt_adapter_dtype
          if requested in (None, "auto"):
              self.prompt_adapter_dtype = model_config.dtype
          elif isinstance(requested, str):
              self.prompt_adapter_dtype = getattr(torch, requested)
  @dataclass
  class MultiModalConfig:
      """Settings that control how multimodal models behave."""

      limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
      """
      Upper bound on the number of multi-modal input instances accepted in
      a single prompt, keyed per
      :class:`~aphrodite.multimodal.MultiModalPlugin`. Defaults to an empty
      mapping.
      """

      # TODO: Add configs to init vision tower or not.
  # Accepted string spellings of a dtype and the torch dtype each selects.
  # Key order is significant: it is the order in which supported dtypes are
  # listed in the ROCm error message built from this table.
  _STR_DTYPE_TO_TORCH_DTYPE = {
      # 16-bit float and its alias
      "half": torch.float16,
      "float16": torch.float16,
      # 32-bit float and its alias
      "float": torch.float32,
      "float32": torch.float32,
      "bfloat16": torch.bfloat16,
  }

  # Keys of the table above that are left out of the list of supported
  # dtypes reported on ROCm.
  _ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
  def _get_and_verify_dtype(
      config: PretrainedConfig,
      dtype: Union[str, torch.dtype],
  ) -> torch.dtype:
      """Resolve the requested dtype against the model config's dtype.

      Args:
          config: Model config; its ``torch_dtype`` (float32 when missing or
              None) is the reference dtype, and ``model_type`` is consulted
              for the "auto" case.
          dtype: A torch dtype, "auto", or a key of
              ``_STR_DTYPE_TO_TORCH_DTYPE`` (matched case-insensitively).

      Returns:
          The torch dtype to use.

      Raises:
          ValueError: If ``dtype`` is not recognised, or if the resolved
              dtype is float32 while ``is_hip()`` is true.
      """
      # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
      # because config.torch_dtype can be None.
      config_dtype = getattr(config, "torch_dtype", None)
      if config_dtype is None:
          config_dtype = torch.float32

      if isinstance(dtype, str):
          dtype = dtype.lower()
          if dtype != "auto":
              torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(dtype)
              if torch_dtype is None:
                  raise ValueError(f"Unknown dtype: {dtype}")
          elif config_dtype != torch.float32:
              # "auto" keeps whatever non-float32 dtype the config declares.
              torch_dtype = config_dtype
          elif config.model_type == "gemma2":
              logger.info(
                  "For Gemma 2, we downcast float32 to bfloat16 instead "
                  "of float16 by default. Please specify `dtype` if you "
                  "want to use float16.")
              torch_dtype = torch.bfloat16
          else:
              # Following the common practice, we use float16 for float32
              # models.
              torch_dtype = torch.float16
      elif isinstance(dtype, torch.dtype):
          torch_dtype = dtype
      else:
          raise ValueError(f"Unknown dtype: {dtype}")

      if is_hip() and torch_dtype == torch.float32:
          rocm_supported_dtypes = [
              name for name in _STR_DTYPE_TO_TORCH_DTYPE
              if name not in _ROCM_NOT_SUPPORTED_DTYPE
          ]
          raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                           f"Supported dtypes are {rocm_supported_dtypes}")

      # Verify the dtype. Upcasting to float32 and downcasting from float32
      # are both allowed silently; only a cast between float16 and bfloat16
      # is reported with a warning.
      if torch_dtype != config_dtype:
          involves_float32 = (torch_dtype == torch.float32
                              or config_dtype == torch.float32)
          if not involves_float32:
              logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

      return torch_dtype
  def _get_and_verify_max_len(
      hf_config: PretrainedConfig,
      max_model_len: Optional[int],
      disable_sliding_window: bool,
      sliding_window_len: Optional[Union[int, List[Optional[int]]]],
      rope_scaling_arg: Optional[Dict[str, Any]],
      spec_target_max_model_len: Optional[int] = None,
  ) -> int:
      """Get and verify the model's maximum length.

      The limit is derived from the smallest length-like attribute found on
      ``hf_config``, optionally capped by the sliding window and scaled by
      the config's ``rope_scaling``; a user-supplied ``max_model_len`` is
      then checked against it.

      Args:
          hf_config: Model config to read length attributes and
              ``rope_scaling`` from. Its ``rope_scaling`` is overwritten
              when dynamic RoPE scaling is applied.
          max_model_len: User-requested maximum length, or None to use the
              derived one.
          disable_sliding_window: Whether the sliding window is manually
              disabled, in which case the window length caps the limit.
          sliding_window_len: Sliding window length(s) from the model
              config, if any.
          rope_scaling_arg: Not read by this function.
          spec_target_max_model_len: Target model's max length, used as a
              fallback when this is a speculative draft model whose config
              has no length key.

      Returns:
          The maximum model length to use.

      Raises:
          NotImplementedError: If the sliding window is disabled for a model
              with rope scaling, or together with a ``model_max_length``
              override.
          ValueError: If ``max_model_len`` exceeds what the config allows
              and dynamic RoPE scaling is not enabled.
      """
      possible_keys = [
          # Cohere: needs to prioritize this over "max_position_embeddings"
          "model_max_length",
          # OPT
          "max_position_embeddings",
          # GPT-2
          "n_positions",
          # MPT
          "max_seq_len",
          # ChatGLM2
          "seq_length",
          # Command-R
          "model_max_length",
          # Others
          "max_sequence_length",
          "max_seq_length",
          "seq_len",
      ]

      # Choose the smallest "max_length" from the possible keys, remembering
      # which key supplied it for the error message further down.
      derived_max_model_len = float("inf")
      max_len_key = None
      for candidate_key in possible_keys:
          candidate_len = getattr(hf_config, candidate_key, None)
          if candidate_len is None:
              continue
          if candidate_len < derived_max_model_len:
              max_len_key = candidate_key
              derived_max_model_len = candidate_len

      # If sliding window is manually disabled, max_length should be less
      # than the sliding window length in the model config.
      if disable_sliding_window and sliding_window_len is not None:
          window_cap = get_min_sliding_window(sliding_window_len)
          if window_cap < derived_max_model_len:
              max_len_key = "sliding_window"
              derived_max_model_len = window_cap

      # If none of the keys were found in the config, fall back to an
      # explicit value if there is one, else use a default and log a warning.
      if derived_max_model_len == float("inf"):
          if max_model_len is not None:
              # If max_model_len is specified, we use it.
              return max_model_len

          if spec_target_max_model_len is not None:
              # If this is a speculative draft model, we use the max model
              # len from the target model.
              return spec_target_max_model_len

          default_max_len = 2048
          logger.warning(
              "The model's config.json does not contain any of the following "
              "keys to determine the original maximum length of the model: "
              f"{possible_keys}. Assuming the model's maximum length is "
              f"{default_max_len}.")
          derived_max_model_len = default_max_len

      rope_scaling = getattr(hf_config, "rope_scaling", None)
      if rope_scaling is not None:
          rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
          # These rope types leave the derived length untouched.
          if rope_type not in {"su", "longrope", "llama3"}:
              if disable_sliding_window:
                  # TODO: Find a model that supports rope_scaling
                  # with sliding window to see if this case should be
                  # allowed.
                  raise NotImplementedError(
                      "Disabling sliding window is not supported for models "
                      "with rope_scaling. Please raise an issue so we can "
                      "investigate.")

              if rope_type == "mrope":
                  scaling_factor = 1
              else:
                  assert "factor" in rope_scaling
                  scaling_factor = rope_scaling["factor"]

              # For yarn the factor is applied to the original length.
              if rope_type == "yarn":
                  derived_max_model_len = rope_scaling[
                      "original_max_position_embeddings"]
              derived_max_model_len *= scaling_factor

      # Without a user-specified length the derived one is the answer.
      if max_model_len is None:
          return int(derived_max_model_len)

      # If the user specified a max length, make sure it is smaller than the
      # derived length from the HF model config.
      if max_model_len > derived_max_model_len:
          # Some models might have a separate key for specifying
          # model_max_length that will be bigger than derived_max_model_len.
          # We compare user input with model_max_length and allow this
          # override when it's smaller.
          model_max_length = getattr(hf_config, "model_max_length", None)
          if envs.APHRODITE_DYNAMIC_ROPE_SCALING:
              scaling_factor = max_model_len / derived_max_model_len
              hf_config.rope_scaling = {
                  "factor": scaling_factor,
                  "type": "dynamic",
              }
              logger.info(
                  "Using dynamic RoPE scaling to extend the model's max "
                  f"context length from {derived_max_model_len} to "
                  f"{max_model_len}.")
              derived_max_model_len = max_model_len
          elif (model_max_length is not None
                and max_model_len <= model_max_length):
              if disable_sliding_window:
                  # TODO: Find a model that has model_max_length
                  # with sliding window to see if this case should be
                  # allowed.
                  raise NotImplementedError(
                      "Disabling sliding window is not supported for models "
                      "model_max_length in the config. Please raise an issue "
                      "so we can investigate.")
          else:
              raise ValueError(
                  f"User-specified max_model_len ({max_model_len}) is greater "
                  f"than the derived max_model_len ({max_len_key}="
                  f"{derived_max_model_len} or model_max_length="
                  f"{model_max_length} in model's config.json). To allow "
                  "greater lengths, please set the env var "
                  "APHRODITE_DYNAMIC_ROPE_SCALING=1")

      return int(max_model_len)
  def get_min_sliding_window(
          sliding_window: Union[int, List[Optional[int]]]) -> int:
      """Return the smallest sliding window length.

      A single value is returned unchanged. For a list, entries that are
      None are ignored and the minimum of the rest is returned.
      """
      if not isinstance(sliding_window, list):
          return sliding_window
      concrete_windows = [
          window for window in sliding_window if window is not None
      ]
      return min(concrete_windows)
  def get_served_model_name(model: str,
                            served_model_name: Optional[Union[str,
                                                              List[str]]]):
      """Resolve the name under which the model is served.

      - A non-empty list yields its first entry.
      - A non-empty string is used directly.
      - Any empty value (None, an empty string, an empty list) falls back
        to `model`.
      """
      if served_model_name:
          if isinstance(served_model_name, list):
              return served_model_name[0]
          return served_model_name
      return model
  @dataclass
  class DecodingConfig:
      """Dataclass which contains the decoding strategy of the engine"""

      # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
      guided_decoding_backend: str = 'lm-format-enforcer'

      def __post_init__(self):
          """Reject any backend that is not one of the supported names.

          Raises:
              ValueError: If ``guided_decoding_backend`` is neither
                  'outlines' nor 'lm-format-enforcer'.
          """
          valid_guided_backends = ['outlines', 'lm-format-enforcer']
          backend = self.guided_decoding_backend
          if backend not in valid_guided_backends:
              # The closing quote and the space after the comma were missing,
              # which made the offending value run into the rest of the text.
              raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                               f"must be one of {valid_guided_backends}")
  @dataclass(frozen=True)
  class EngineConfig:
      """Immutable bundle of all engine-related configuration objects.

      Grouping them simplifies passing the distinct configurations around
      the codebase.
      """

      model_config: ModelConfig
      cache_config: CacheConfig
      parallel_config: ParallelConfig
      scheduler_config: SchedulerConfig
      device_config: DeviceConfig
      load_config: LoadConfig
      lora_config: Optional[LoRAConfig]
      speculative_config: Optional[SpeculativeConfig]
      decoding_config: Optional[DecodingConfig]
      prompt_adapter_config: Optional[PromptAdapterConfig]

      def __post_init__(self):
          """Verify configs are valid & consistent with each other.

          The model and cache configs are always checked; the LoRA and
          prompt adapter configs are checked only when they are set.
          """
          model_config = self.model_config
          parallel_config = self.parallel_config

          model_config.verify_async_output_proc(parallel_config,
                                                self.speculative_config,
                                                self.device_config)
          model_config.verify_with_parallel_config(parallel_config)
          self.cache_config.verify_with_parallel_config(parallel_config)

          lora_config = self.lora_config
          if lora_config:
              lora_config.verify_with_model_config(model_config)
              lora_config.verify_with_scheduler_config(self.scheduler_config)
              lora_config.verify_with_parallel_config(parallel_config)

          if self.prompt_adapter_config:
              self.prompt_adapter_config.verify_with_model_config(
                  model_config)

      def to_dict(self):
          """Return the configs as a dictionary, for use in **kwargs.

          Keys are the dataclass field names, in declaration order.
          """
          return {
              config_field.name: getattr(self, config_field.name)
              for config_field in fields(self)
          }