config.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319
  1. import enum
  2. import json
  3. import os
  4. from dataclasses import dataclass, field, fields
  5. from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
  6. import torch
  7. from loguru import logger
  8. from transformers import PretrainedConfig
  9. from aphrodite.common.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
  10. from aphrodite.modeling.models import ModelRegistry
  11. from aphrodite.quantization import QUANTIZATION_METHODS
  12. from aphrodite.transformers_utils.config import get_config, get_hf_text_config
  13. if TYPE_CHECKING:
  14. from ray.util.placement_group import PlacementGroup
  15. from aphrodite.modeling.model_loader.loader import BaseModelLoader
  # Opt-in switch: when the APHRODITE_USE_MODELSCOPE environment variable is
  # set to "true" (case-insensitive), models are loaded from ModelScope
  # instead of Hugging Face Hub.
  APHRODITE_USE_MODELSCOPE = (os.getenv("APHRODITE_USE_MODELSCOPE",
                                        "False").lower() == "true")

  # Number of bytes in one GiB; used to convert GiB figures to bytes and back.
  _GB = 2**30

  # Lower bound applied to `max_num_batched_tokens` by SchedulerConfig when the
  # model runs in embedding mode and no explicit value is given.
  _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32 * 1024
  class ModelConfig:
      """Configuration for the model.

      Args:
          model: Name or path of the huggingface model to use.
          tokenizer: Name or path of the huggingface tokenizer to use.
          tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer
              if available, and "slow" will always use the slow tokenizer.
          trust_remote_code: Trust remote code (e.g., from HuggingFace) when
              downloading the model and tokenizer.
          dtype: Data type for model weights and activations. The "auto"
              option will use FP16 precision for FP32 and FP16 models, and
              BF16 precision for BF16 models.
          seed: Random seed for reproducibility.
          revision: The specific model version to use. It can be a branch
              name, a tag name, or a commit id. If unspecified, will use the
              default version.
          code_revision: The specific revision to use for the model code on
              Hugging Face Hub. It can be a branch name, a tag name, or a
              commit id. If unspecified, will use the default version.
          rope_scaling: Dictionary containing the scaling configuration for
              the RoPE embeddings. When using this flag, don't update
              `max_position_embeddings` to the expected new maximum.
          tokenizer_revision: The specific tokenizer version to use. It can
              be a branch name, a tag name, or a commit id. If unspecified,
              will use the default version.
          max_model_len: Maximum length of a sequence (including prompt and
              output). If None, will be derived from the model.
          quantization: Quantization method that was used to quantize the
              model weights. If None, we assume the model weights are not
              quantized.
          load_in_4bit: Whether to load the FP16 model in bitsandbytes 4bit
              format. Works with AWQ models as well as FP16.
          load_in_8bit: Whether to load the FP16 model in 8bit format. Slower
              than load_in_smooth in terms of throughput.
          load_in_smooth: Whether to load the FP16 model in smoothquant
              format.
          deepspeed_fp_bits: Number of bits to use for DeepSpeed FP
              quantization. Supported number of bits are: 4, 6, 8, 12.
          quantization_param_path: Path to JSON file containing scaling
              factors. Used to load KV cache scaling factors into the model
              when KV cache type is FP8_E4M3 on ROCm (AMD GPU). In the future
              these will also be used to load activation and weight scaling
              factors when the model dtype is FP8_E4M3 on ROCm.
          enforce_eager: Whether to enforce eager execution. If True, we will
              disable CUDA graph and always execute the model in eager mode.
              If False, we will use CUDA graph and eager execution in hybrid.
          max_context_len_to_capture: DEPRECATED; passing any non-None value
              raises ValueError. Use max_seq_len_to_capture instead.
          max_seq_len_to_capture: Maximum sequence len covered by CUDA
              graphs. When a sequence has context length larger than this,
              we fall back to eager mode. Clamped to `max_model_len`.
          max_logprobs: Stored on the config as-is; not validated here
              (presumably an upper bound on returned logprobs -- confirm
              against the consumer).
          skip_tokenizer_init: If true, skip initialization of tokenizer and
              detokenizer (the tokenizer mode is then not validated).

      Raises:
          ValueError: If `max_context_len_to_capture` is given, or if the
              tokenizer mode or quantization settings are invalid.
      """

      def __init__(
          self,
          model: str,
          tokenizer: str,
          tokenizer_mode: str,
          trust_remote_code: bool,
          dtype: Union[str, torch.dtype],
          seed: int,
          revision: Optional[str] = None,
          code_revision: Optional[str] = None,
          rope_scaling: Optional[dict] = None,
          tokenizer_revision: Optional[str] = None,
          max_model_len: Optional[int] = None,
          quantization: Optional[str] = None,
          load_in_4bit: bool = False,
          load_in_8bit: bool = False,
          load_in_smooth: bool = False,
          deepspeed_fp_bits: Optional[int] = None,
          quantization_param_path: Optional[str] = None,
          enforce_eager: bool = True,
          max_context_len_to_capture: Optional[int] = None,
          max_seq_len_to_capture: Optional[int] = None,
          max_logprobs: int = 5,
          skip_tokenizer_init: bool = False,
      ) -> None:
          self.model = model
          self.tokenizer = tokenizer
          self.tokenizer_mode = tokenizer_mode
          self.trust_remote_code = trust_remote_code
          self.seed = seed
          self.revision = revision
          self.code_revision = code_revision
          self.rope_scaling = rope_scaling
          self.tokenizer_revision = tokenizer_revision
          self.quantization = quantization
          self.load_in_4bit = load_in_4bit
          self.load_in_8bit = load_in_8bit
          self.load_in_smooth = load_in_smooth
          self.deepspeed_fp_bits = deepspeed_fp_bits
          self.quantization_param_path = quantization_param_path
          self.enforce_eager = enforce_eager
          self.max_context_len_to_capture = max_context_len_to_capture
          if self.max_context_len_to_capture is not None:
              raise ValueError("`max_context_len_to_capture` is deprecated. "
                               "Use `max_seq_len_to_capture` instead.")
          # A falsy max_seq_len_to_capture (None or 0) ends up as None here
          # and is later replaced by max_model_len in _verify_cuda_graph().
          self.max_seq_len_to_capture = (max_seq_len_to_capture
                                         or max_context_len_to_capture)
          self.max_logprobs = max_logprobs
          self.skip_tokenizer_init = skip_tokenizer_init

          self.hf_config = get_config(self.model, trust_remote_code, revision,
                                      code_revision, rope_scaling)
          self.hf_text_config = get_hf_text_config(self.hf_config)
          self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
          self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                       max_model_len)

          # Configs advertising exactly 131072 positions without any RoPE
          # scaling get the "extended" scaling type injected.
          if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
                  and getattr(self.hf_config, "rope_scaling", None) is None):
              self.hf_config.update({"rope_scaling": {
                  "type": "extended",
              }})

          if not self.skip_tokenizer_init:
              self._verify_tokenizer_mode()
          self._verify_embedding_mode()
          self._verify_quantization()
          self._verify_cuda_graph()

      def _verify_tokenizer_mode(self) -> None:
          """Lower-case `tokenizer_mode` and reject anything but auto/slow."""
          tokenizer_mode = self.tokenizer_mode.lower()
          if tokenizer_mode not in ["auto", "slow"]:
              raise ValueError(
                  f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                  "either 'auto' or 'slow'.")
          self.tokenizer_mode = tokenizer_mode

      def _verify_embedding_mode(self) -> None:
          """Set `embedding_mode` from the architectures in the HF config."""
          # `architectures` may be present but set to None; treat that the
          # same as a missing attribute instead of iterating over None.
          architectures = getattr(self.hf_config, "architectures", None) or []
          self.embedding_mode = any(
              ModelRegistry.is_embedding_model(arch) for arch in architectures)

      def _use_bnb_8bit(self, quant_mode: str, label: str) -> None:
          """Switch to an 8-bit bitsandbytes config and force eager mode.

          Args:
              quant_mode: Value stored under "quant_mode" in the generated
                  quantization config ("llm_int8" or "smoothquant").
              label: Prefix of the error message raised when another
                  quantization method is already selected.

          Raises:
              ValueError: If a quantization method other than "bnb" is set.
          """
          if self.quantization is None:
              self.quantization = "bnb"
          elif self.quantization != "bnb":
              raise ValueError(f"{label} quantization is not supported in "
                               f"{self.quantization}.")
          self.hf_config.quantization_config = {
              "bits": 8,
              "quant_mode": quant_mode,
              "quant_method": "bnb",
              "group_size": 128,
              "zero_point": True,
              "from_float": True
          }
          self.enforce_eager = True

      def _verify_quantization(self) -> None:
          """Resolve and validate the effective quantization method.

          Combines the `quantization` argument, the checkpoint's own
          `quantization_config` (if any) and the `load_in_*` /
          `deepspeed_fp_bits` options. May overwrite `self.quantization`,
          `self.hf_config.quantization_config` and `self.enforce_eager`.

          Raises:
              ValueError: On conflicting, unknown or unsupported settings.
          """
          supported_quantization = [*QUANTIZATION_METHODS]
          rocm_supported_quantization = ["gptq", "squeezellm"]
          if self.quantization is not None:
              self.quantization = self.quantization.lower()

          # Parse quantization method from the HF model config, if available.
          quant_cfg = getattr(self.hf_config, "quantization_config", None)
          if quant_cfg is not None:
              quant_method = quant_cfg.get("quant_method", "").lower()

              # Let each registered method claim the checkpoint; the first
              # one returning a truthy override wins.
              for method in QUANTIZATION_METHODS.values():
                  quantization_override = method.override_quantization_method(
                      quant_cfg, self.quantization)
                  if quantization_override:
                      quant_method = quantization_override
                      self.quantization = quantization_override
                      break

              # Verify quantization configurations.
              if self.quantization is None:
                  self.quantization = quant_method
              elif self.quantization != quant_method:
                  raise ValueError(
                      "Quantization method specified in the model config "
                      f"({quant_method}) does not match the quantization "
                      f"method specified in the `quantization` argument "
                      f"({self.quantization}).")

          if self.load_in_4bit:
              # the kernels seem to not work with 4bit weight_only
              if torch.cuda.get_device_capability(0)[0] < 8:
                  raise ValueError(
                      "load_in_4bit quantization is not supported on GPUs with "
                      "compute capability less than 8.0.")

              if self.quantization is None:
                  self.quantization = "bnb"
                  self.hf_config.quantization_config = {
                      "bits": 4,
                      "quant_mode": "weight_only",
                      "quant_method": "bnb",
                      "group_size": 128,
                      "zero_point": True,
                      "from_float": True
                  }
              elif self.quantization == "awq":
                  logger.warning("AWQ model is being loaded in 4bit bnb format.")
                  self.quantization = "bnb"
                  self.hf_config.quantization_config = {
                      "zero_point": True,
                      "q_group_size": 128,
                      "w_bit": 4,
                      "version": "gemm"
                  }
              elif self.quantization != "bnb":
                  raise ValueError("4bit quantization is not supported in "
                                   f"{self.quantization}.")

          if self.load_in_8bit:
              self._use_bnb_8bit("llm_int8", "8bit")

          if self.load_in_smooth:
              self._use_bnb_8bit("smoothquant", "Smooth")

          if self.quantization == "deepspeedfp":
              # Default group size depends on the bit width; the
              # DEEPSPEED_GROUP_SIZE environment variable overrides it.
              gs = 32 if self.deepspeed_fp_bits == 4 else 128
              self.hf_config.quantization_config = {
                  "bits": self.deepspeed_fp_bits,
                  "group_size": int(os.environ.get("DEEPSPEED_GROUP_SIZE", gs)),
                  "quant_method": "deepspeedfp"
              }

          if self.quantization is not None:
              if self.quantization not in supported_quantization:
                  raise ValueError(
                      f"Unknown quantization method: {self.quantization}. Must "
                      f"be one of {supported_quantization}.")
              if is_hip(
              ) and self.quantization not in rocm_supported_quantization:
                  raise ValueError(
                      f"{self.quantization} quantization is currently not "
                      "supported in ROCm.")
              if (self.quantization
                      not in ["marlin", "gptq_marlin_24", "gptq_marlin"]):
                  logger.warning(
                      f"{self.quantization} quantization is not fully "
                      "optimized yet. The speed can be slower than "
                      "non-quantized models.")
          if self.quantization == "deepspeedfp" and self.deepspeed_fp_bits \
                  is None:
              raise ValueError(
                  "deepspeed_fp_bits must be specified when using "
                  "deepspeedfp quantization.")

      def _verify_cuda_graph(self) -> None:
          """Default `max_seq_len_to_capture` and clamp it to max_model_len."""
          if self.max_seq_len_to_capture is None:
              self.max_seq_len_to_capture = self.max_model_len
          self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                            self.max_model_len)

      def verify_with_parallel_config(
          self,
          parallel_config: "ParallelConfig",
      ) -> None:
          """Check that heads and layers divide evenly across the parallel
          sizes.

          Raises:
              ValueError: If the number of attention heads is not divisible
                  by the tensor parallel size, or the number of hidden layers
                  is not divisible by the pipeline parallel size.
          """
          total_num_attention_heads = self.hf_text_config.num_attention_heads
          tensor_parallel_size = parallel_config.tensor_parallel_size
          if total_num_attention_heads % tensor_parallel_size != 0:
              raise ValueError(
                  f"Total number of attention heads ({total_num_attention_heads})"
                  " must be divisible by tensor parallel size "
                  f"({tensor_parallel_size}).")

          total_num_hidden_layers = self.hf_text_config.num_hidden_layers
          pipeline_parallel_size = parallel_config.pipeline_parallel_size
          if total_num_hidden_layers % pipeline_parallel_size != 0:
              raise ValueError(
                  f"Total number of hidden layers ({total_num_hidden_layers}) "
                  "must be divisible by pipeline parallel size "
                  f"({pipeline_parallel_size}).")

      def get_sliding_window(self) -> Optional[int]:
          """Get the sliding window size, or None if disabled.
          """
          # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
          # addition to sliding window size. We check if that field is present
          # and if it's False, return None.
          if (hasattr(self.hf_text_config, "use_sliding_window")
                  and not self.hf_text_config.use_sliding_window):
              return None
          return getattr(self.hf_text_config, "sliding_window", None)

      def get_vocab_size(self) -> int:
          """Return `vocab_size` from the HF text config."""
          return self.hf_text_config.vocab_size

      def get_hidden_size(self) -> int:
          """Return `hidden_size` from the HF text config."""
          return self.hf_text_config.hidden_size

      def get_head_size(self) -> int:
          """Return `head_dim` if the config defines it, otherwise
          hidden_size // num_attention_heads."""
          if hasattr(self.hf_text_config, "head_dim"):
              return self.hf_text_config.head_dim
          # FIXME: This may not be true for all models.
          return (self.hf_text_config.hidden_size //
                  self.hf_text_config.num_attention_heads)

      def get_total_num_kv_heads(self) -> int:
          """Returns the total number of KV heads."""
          # For GPTBigCode & Falcon:
          # NOTE: for falcon, when new_decoder_architecture is True, the
          # multi_query flag is ignored and we use n_head_kv for the number of
          # KV heads.
          falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
          new_decoder_arch_falcon = (
              self.hf_config.model_type in falcon_model_types
              and getattr(self.hf_config, "new_decoder_architecture", False))
          if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                     "multi_query", False):
              # Multi-query attention, only one KV head.
              # Currently, tensor parallelism is not supported in this case.
              return 1

          # For DBRX and MPT
          if self.hf_config.model_type in ["dbrx", "mpt"]:
              return getattr(self.hf_config.attn_config, "kv_n_heads",
                             self.hf_config.num_attention_heads)

          attributes = [
              # For Falcon:
              "n_head_kv",
              "num_kv_heads",
              # For LLaMA-2:
              "num_key_value_heads",
              # For ChatGLM:
              "multi_query_group_num",
          ]
          # The first attribute that is present and not None wins.
          for attr in attributes:
              num_kv_heads = getattr(self.hf_text_config, attr, None)
              if num_kv_heads is not None:
                  return num_kv_heads

          # For non-grouped-query attention models, the number of KV heads is
          # equal to the number of attention heads.
          return self.hf_text_config.num_attention_heads

      def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
          """Returns the number of KV heads per GPU."""
          total_num_kv_heads = self.get_total_num_kv_heads()
          # If tensor parallelism is used, we divide the number of KV heads by
          # the tensor parallel size. We will replicate the KV heads in the
          # case where the number of KV heads is smaller than the tensor
          # parallel size so each GPU has at least one KV head.
          return max(1,
                     total_num_kv_heads // parallel_config.tensor_parallel_size)

      def get_num_attention_heads(self,
                                  parallel_config: "ParallelConfig") -> int:
          """Return attention heads divided by the tensor parallel size."""
          return self.hf_text_config.num_attention_heads // \
              parallel_config.tensor_parallel_size

      def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
          """Return hidden layers divided by the pipeline parallel size."""
          total_num_hidden_layers = self.hf_text_config.num_hidden_layers
          return total_num_hidden_layers // parallel_config.pipeline_parallel_size
  class CacheConfig:
      """Settings that govern the KV cache.

      Args:
          block_size: Size of a cache block in number of tokens.
          gpu_memory_utilization: Fraction of GPU memory to use for the
              Aphrodite execution. Values above 1.0 are rejected.
          swap_space: Size of the CPU swap space per GPU (in GiB); stored in
              bytes as `swap_space_bytes`.
          cache_dtype: Data type for kv cache storage. One of "auto", "fp8",
              "fp8_e4m3" or "fp8_e5m2".
          num_gpu_blocks_override: Number of GPU blocks to use. This overrides
              the profiled num_gpu_blocks if specified. Does nothing if None.
          sliding_window: Stored on the config as-is.
          enable_prefix_caching: Stored on the config as-is.

      Raises:
          ValueError: If `gpu_memory_utilization` exceeds 1.0 or
              `cache_dtype` is not recognised.
      """

      def __init__(
          self,
          block_size: int,
          gpu_memory_utilization: float,
          swap_space: int,
          cache_dtype: str,
          num_gpu_blocks_override: Optional[int] = None,
          sliding_window: Optional[int] = None,
          enable_prefix_caching: bool = False,
      ) -> None:
          # NOTE: the assignment order below determines the key order that
          # metrics_info() reports, so keep it stable.
          self.block_size = block_size
          self.gpu_memory_utilization = gpu_memory_utilization
          self.swap_space_bytes = _GB * swap_space
          self.num_gpu_blocks_override = num_gpu_blocks_override
          self.cache_dtype = cache_dtype
          self.sliding_window = sliding_window
          self.enable_prefix_caching = enable_prefix_caching

          for check in (self._verify_args, self._verify_cache_dtype):
              check()

          # Placeholders; filled in once profiling has run.
          self.num_gpu_blocks = None
          self.num_cpu_blocks = None

      def metrics_info(self):
          """Return every instance attribute as a str -> str mapping.

          The result is meant to be fed to prometheus as metrics info.
          """
          info = {}
          for name, value in self.__dict__.items():
              info[name] = str(value)
          return info

      def _verify_args(self) -> None:
          """Reject a GPU memory fraction above 1.0."""
          if not self.gpu_memory_utilization > 1.0:
              return
          raise ValueError(
              "GPU memory utilization must be less than 1.0. Got "
              f"{self.gpu_memory_utilization}.")

      def _verify_cache_dtype(self) -> None:
          """Validate `cache_dtype`, logging a note for the fp8 variants."""
          dtype = self.cache_dtype
          if dtype == "auto":
              return
          if dtype not in ("fp8", "fp8_e4m3", "fp8_e5m2"):
              raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
          logger.info(
              "Using fp8 data type to store kv cache. It reduces the GPU "
              "memory footprint and boosts the performance. "
              "Meanwhile, it may cause accuracy drop without a proper "
              "scaling factor")

      def verify_with_parallel_config(
          self,
          parallel_config: "ParallelConfig",
      ) -> None:
          """Check the requested swap space against total CPU memory.

          Warns when the swap space exceeds 40% of CPU memory.

          Raises:
              ValueError: If the swap space exceeds 70% of CPU memory.
          """
          cpu_total = get_cpu_memory()
          # FIXME: Here, it is assumed that the GPUs in a tensor parallel
          # group are in the same node. However, the GPUs may span multiple
          # nodes.
          gpus_on_node = parallel_config.tensor_parallel_size
          swap_total = self.swap_space_bytes * gpus_on_node

          summary = (f"{swap_total / _GB:.2f} GiB out of "
                     f"the {cpu_total / _GB:.2f} GiB total CPU memory is "
                     "allocated for the swap space.")
          if swap_total > 0.7 * cpu_total:
              raise ValueError("Too large swap space. " + summary)
          if swap_total > 0.4 * cpu_total:
              logger.warning("Possibly too large swap space. " + summary)
  @dataclass
  class TokenizerPoolConfig:
      """Settings for the tokenizer worker pool.

      Args:
          pool_size: Number of tokenizer workers in the pool.
          pool_type: Type of the pool. Only "ray" is accepted.
          extra_config: Additional config for the pool. The way the config
              will be used depends on the pool type. Must be a dict.

      Raises:
          ValueError: If `pool_type` is unknown or `extra_config` is not a
              dictionary.
      """
      pool_size: int
      pool_type: str
      extra_config: dict

      def __post_init__(self):
          known_pool_types = ("ray", )
          if self.pool_type not in known_pool_types:
              raise ValueError(f"Unknown pool type: {self.pool_type}")
          if not isinstance(self.extra_config, dict):
              raise ValueError("extra_config must be a dictionary.")

      @classmethod
      def create_config(
          cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
          tokenizer_pool_extra_config: Optional[Union[str, dict]]
      ) -> Optional["TokenizerPoolConfig"]:
          """Build a TokenizerPoolConfig, or return None when no pool is
          requested.

          Args:
              tokenizer_pool_size: Number of tokenizer workers in the pool.
                  A falsy value (e.g. 0) means "no pool" and yields None.
              tokenizer_pool_type: Type of the pool.
              tokenizer_pool_extra_config: Additional config for the pool.
                  The way the config will be used depends on the pool type.
                  A string is parsed as JSON; any other falsy value becomes
                  an empty dict.
          """
          if not tokenizer_pool_size:
              return None

          extra = tokenizer_pool_extra_config
          if isinstance(extra, str):
              extra = json.loads(extra)
          else:
              extra = extra or {}
          return cls(tokenizer_pool_size, tokenizer_pool_type, extra)
  class LoadFormat(str, enum.Enum):
      """Weight formats accepted by `LoadConfig.load_format`.

      Members are also `str` instances, so they compare equal to their
      lowercase string values.
      """
      AUTO = "auto"
      PT = "pt"
      SAFETENSORS = "safetensors"
      NPCACHE = "npcache"
      DUMMY = "dummy"
      TENSORIZER = "tensorizer"
      SHARDED_STATE = "sharded_state"


  @dataclass
  class LoadConfig:
      """Configuration for loading the model weights.

      Args:
          load_format: The format of the model weights to load:
              "auto" will try to load the weights in the safetensors format
              and fall back to the pytorch bin format if safetensors format
              is not available.
              "pt" will load the weights in the pytorch bin format.
              "safetensors" will load the weights in the safetensors format.
              "npcache" will load the weights in pytorch format and store
              a numpy cache to speed up the loading.
              "dummy" will initialize the weights with random values, which
              is mainly for profiling.
              "tensorizer" will use CoreWeave's tensorizer library for
              fast weight loading.
              A string is matched case-insensitively and converted to a
              `LoadFormat` member; a non-string value is kept untouched.
          download_dir: Directory to download and load the weights, default
              to the default cache directory of huggingface.
          model_loader_extra_config: Extra settings for the model loader.
              A JSON string is parsed; None or any other falsy value is
              normalised to an empty dict.

      Raises:
          ValueError: If a string `load_format` names no `LoadFormat` member
              or is not supported on ROCm.
      """

      load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
      download_dir: Optional[str] = None
      model_loader_extra_config: Optional[Union[str, dict]] = field(
          default_factory=dict)

      def __post_init__(self):
          extra_config = self.model_loader_extra_config or {}
          if isinstance(extra_config, str):
              extra_config = json.loads(extra_config)
          # Always store the normalised value so that None / "" do not leak
          # through to code that expects a mapping.
          self.model_loader_extra_config = extra_config
          self._verify_load_format()

      def _verify_load_format(self) -> None:
          """Convert a string `load_format` to `LoadFormat` and check ROCm
          support."""
          # Non-string values are accepted as they are.
          if not isinstance(self.load_format, str):
              return

          load_format = self.load_format.lower()
          self.load_format = LoadFormat(load_format)

          # Currently empty, so the ROCm check below never fires.
          rocm_not_supported_load_format: List[str] = []
          if is_hip() and load_format in rocm_not_supported_load_format:
              rocm_supported_load_format = [
                  f for f in LoadFormat.__members__
                  if (f not in rocm_not_supported_load_format)
              ]
              raise ValueError(
                  f"load format '{load_format}' is not supported in ROCm. "
                  f"Supported load formats are "
                  f"{rocm_supported_load_format}")
  class ParallelConfig:
      """Settings for distributed execution.

      Args:
          pipeline_parallel_size: Number of pipeline parallel groups.
          tensor_parallel_size: Number of tensor parallel groups.
          worker_use_ray: Deprecated, use distributed_executor_backend
              instead.
          max_parallel_loading_workers: Maximum number of multiple batches
              when load model sequentially. To avoid RAM OOM when using
              tensor parallel and large models.
          disable_custom_all_reduce: Disable the custom all-reduce kernel and
              fall back to NCCL.
          tokenizer_pool_config: Config for the tokenizer pool.
              If None, will use synchronous tokenization.
          ray_workers_use_nsight: Whether to profile Ray workers with nsight,
              see
              https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
          placement_group: ray distributed model workers placement group.
          distributed_executor_backend: Backend to use for distributed model
              workers, either "ray" or "mp" (multiprocessing). If left as
              None while the world size (pipeline_parallel_size *
              tensor_parallel_size) is greater than 1, it defaults to "ray"
              if Ray is installed or "mp" otherwise.

      Raises:
          ValueError: On an unknown backend, on `worker_use_ray` combined
              with a non-ray backend, or on nsight profiling without Ray.
          NotImplementedError: If `pipeline_parallel_size` is greater than 1.
      """

      def __init__(
          self,
          pipeline_parallel_size: int,
          tensor_parallel_size: int,
          worker_use_ray: Optional[bool] = None,
          max_parallel_loading_workers: Optional[int] = None,
          disable_custom_all_reduce: bool = False,
          tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
          ray_workers_use_nsight: bool = False,
          placement_group: Optional["PlacementGroup"] = None,
          distributed_executor_backend: Optional[str] = None,
      ) -> None:
          self.pipeline_parallel_size = pipeline_parallel_size
          self.tensor_parallel_size = tensor_parallel_size
          self.distributed_executor_backend = distributed_executor_backend
          self.max_parallel_loading_workers = max_parallel_loading_workers
          self.disable_custom_all_reduce = disable_custom_all_reduce
          self.tokenizer_pool_config = tokenizer_pool_config
          self.ray_workers_use_nsight = ray_workers_use_nsight
          self.placement_group = placement_group

          self.world_size = pipeline_parallel_size * self.tensor_parallel_size

          # The deprecated flag may only select (or agree with) the ray
          # backend.
          if worker_use_ray:
              if self.distributed_executor_backend is None:
                  self.distributed_executor_backend = "ray"
              elif self.distributed_executor_backend != "ray":
                  raise ValueError(f"worker-use-ray can't be used with "
                                   f"distributed executor backend "
                                   f"'{self.distributed_executor_backend}'.")

          needs_default_backend = (self.distributed_executor_backend is None
                                   and self.world_size > 1)
          if needs_default_backend:
              # Imported lazily, only when a default has to be picked.
              from aphrodite.executor import ray_utils
              if ray_utils.ray is not None:
                  self.distributed_executor_backend = "ray"
              else:
                  self.distributed_executor_backend = "mp"

          self._verify_args()

      def _verify_args(self) -> None:
          """Validate the combination of settings; may turn off the custom
          all-reduce kernel."""
          if self.pipeline_parallel_size > 1:
              raise NotImplementedError(
                  "Pipeline parallelism is not supported yet.")

          backend = self.distributed_executor_backend
          if backend not in ("ray", "mp", None):
              raise ValueError(
                  "Unrecognized distributed executor backend. Supported values "
                  "are 'ray' or 'mp'.")

          custom_all_reduce_wanted = (not self.disable_custom_all_reduce
                                      and self.world_size > 1)
          if custom_all_reduce_wanted:
              if is_hip():
                  self.disable_custom_all_reduce = True
                  logger.info(
                      "Disabled the custom all-reduce kernel because it is "
                      "not supported on AMD GPUs.")
              elif self.pipeline_parallel_size > 1:
                  self.disable_custom_all_reduce = True
                  logger.info(
                      "Disabled the custom all-reduce kernel because it is "
                      "not supported with pipeline parallelism.")

          if self.ray_workers_use_nsight and (
                  self.distributed_executor_backend != "ray"):
              raise ValueError("Unable to use nsight profiling unless workers "
                               "run with Ray.")
  class SchedulerConfig:
      """Scheduler configuration.

      Args:
          max_num_batched_tokens: Upper bound on the tokens processed in one
              iteration. When None, a default is derived from the other
              arguments.
          max_num_seqs: Upper bound on the sequences processed in one
              iteration.
          max_model_len: Longest allowed sequence, counting both the prompt
              and the generated text.
          use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
          num_lookahead_slots: Slots to allocate per sequence per step beyond
              the known token ids. Speculative decoding uses them to store KV
              activations of tokens which may or may not be accepted.
          delay_factor: Apply a delay (of delay factor multiplied by previous
              prompt latency) before scheduling next prompt.
          enable_chunked_prefill: If True, prefill requests can be chunked
              based on the remaining max_num_batched_tokens.
          embedding_mode: Whether the running model is for embedding.
      """

      def __init__(
          self,
          max_num_batched_tokens: Optional[int],
          max_num_seqs: int,
          max_model_len: int,
          use_v2_block_manager: bool = False,
          num_lookahead_slots: int = 0,
          delay_factor: float = 0.0,
          enable_chunked_prefill: bool = False,
          embedding_mode: Optional[bool] = None,
      ) -> None:
          """Store the settings, pick a default token budget if none was
          given, and validate the result.

          Raises:
              ValueError: If the resulting settings are inconsistent (see
                  ``_verify_args``).
          """
          if max_num_batched_tokens is None:
              if enable_chunked_prefill:
                  # For chunked prefill, choose the well-tuned batch size.
                  max_num_batched_tokens = 768
              elif embedding_mode:
                  # For embedding, choose specific value for higher
                  # throughput.
                  max_num_batched_tokens = max(
                      max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
              else:
                  # If max_model_len is too short, use 2048 as the default
                  # value for higher throughput.
                  max_num_batched_tokens = max(max_model_len, 2048)
          self.max_num_batched_tokens = max_num_batched_tokens

          if enable_chunked_prefill:
              logger.info("Chunked prefill is enabled (EXPERIMENTAL).")

          self.max_num_seqs = max_num_seqs
          self.max_model_len = max_model_len
          self.use_v2_block_manager = use_v2_block_manager
          self.num_lookahead_slots = num_lookahead_slots
          self.delay_factor = delay_factor
          # Note the attribute name differs from the constructor argument.
          self.chunked_prefill_enabled = enable_chunked_prefill
          self.embedding_mode = embedding_mode
          self._verify_args()

      def _verify_args(self) -> None:
          """Check that the stored limits are mutually consistent.

          Raises:
              ValueError: If the token budget is below ``max_model_len``
                  while chunked prefill is off, if the token budget is below
                  ``max_num_seqs``, or if ``num_lookahead_slots`` is
                  negative.
          """
          token_budget = self.max_num_batched_tokens

          budget_below_model_len = token_budget < self.max_model_len
          if budget_below_model_len and not self.chunked_prefill_enabled:
              raise ValueError(
                  f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                  f"smaller than max_model_len ({self.max_model_len}). "
                  "This effectively limits the maximum sequence length to "
                  "max_num_batched_tokens and makes Aphrodite reject longer "
                  "sequences. Please increase max_num_batched_tokens or "
                  "decrease max_model_len.")

          if token_budget < self.max_num_seqs:
              raise ValueError(
                  f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                  "be greater than or equal to max_num_seqs "
                  f"({self.max_num_seqs}).")

          if self.num_lookahead_slots < 0:
              raise ValueError(
                  "num_lookahead_slots "
                  f"({self.num_lookahead_slots}) must be greater than or "
                  "equal to 0.")
  class DeviceConfig:
      """Holds the device type string and the matching ``torch.device``.

      Attributes:
          device_type: The requested device type, or the detected one when
              ``device`` is "auto".
          device: The ``torch.device`` to use; it is the CPU device when the
              device type is "neuron".
      """

      def __init__(self, device: str = "auto") -> None:
          """Resolve the device type and build the torch device.

          Args:
              device: An explicit device type, or "auto" to detect one.
          """
          if device != "auto":
              # Device type is assigned explicitly
              self.device_type = device
          elif is_neuron():
              self.device_type = "neuron"
          elif is_cpu():
              self.device_type = "cpu"
          else:
              # We don't call torch.cuda.is_available() here to
              # avoid initializing CUDA before workers are forked
              self.device_type = "cuda"

          # Some device types require processing inputs on CPU
          torch_device_name = ("cpu" if self.device_type in ["neuron"] else
                               self.device_type)
          self.device = torch.device(torch_device_name)
  class SpeculativeConfig:
      """Configuration for speculative decoding.

      The configuration is currently specialized to draft-model speculative
      decoding with top-1 proposals.
      """

      @staticmethod
      def maybe_create_spec_config(
          target_model_config: ModelConfig,
          target_parallel_config: ParallelConfig,
          target_dtype: str,
          speculative_model: Optional[str],
          num_speculative_tokens: Optional[int],
          speculative_max_model_len: Optional[int],
          enable_chunked_prefill: bool,
          use_v2_block_manager: bool,
          speculative_disable_by_batch_size: Optional[int],
          ngram_prompt_lookup_max: Optional[int],
          ngram_prompt_lookup_min: Optional[int],
      ) -> Optional["SpeculativeConfig"]:
          """Create a SpeculativeConfig if possible, else return None.

          This function attempts to create a SpeculativeConfig object based
          on the provided parameters. If neither a speculative model nor a
          number of speculative tokens is given, it returns None.

          Args:
              target_model_config (ModelConfig): The configuration of the
                  target model.
              target_parallel_config (ParallelConfig): The parallel
                  configuration for the target model.
              target_dtype (str): The data type used for the target model.
                  It is accepted but not read by this function.
              speculative_model (Optional[str]): The name of the speculative
                  model, if provided. The special value "[ngram]" selects
                  ngram prompt lookup and reuses the target configs.
              num_speculative_tokens (Optional[int]): The number of
                  speculative tokens, if provided.
              speculative_max_model_len (Optional[int]): The maximum model
                  len of the speculative model. Used when testing the ability
                  to skip speculation for some sequences.
              enable_chunked_prefill (bool): Whether Aphrodite is configured
                  to use chunked prefill or not. Used for raising an error
                  since it's not yet compatible with spec decode.
              use_v2_block_manager (bool): Whether Aphrodite is configured to
                  use the v2 block manager or not. Used for raising an error
                  since the v2 block manager is required with spec decode.
              speculative_disable_by_batch_size (Optional[int]): Disable
                  speculative decoding for new incoming requests when the
                  number of enqueue requests is larger than this value, if
                  provided.
              ngram_prompt_lookup_max (Optional[int]): Max size of ngram
                  token window, if provided.
              ngram_prompt_lookup_min (Optional[int]): Min size of ngram
                  token window, if provided.

          Returns:
              Optional["SpeculativeConfig"]: An instance of SpeculativeConfig
                  if the necessary conditions are met, else None.

          Raises:
              ValueError: If only one of speculative_model and
                  num_speculative_tokens is given, if the batch-size
                  threshold is below 2, if chunked prefill is enabled, if
                  the v2 block manager is not enabled, or if the ngram
                  window bounds are invalid.
          """
          if speculative_model is None and num_speculative_tokens is None:
              return None

          # Exactly one of the two being set is a user error in either
          # direction. Raising here (rather than asserting) keeps the check
          # active when Python runs with assertions disabled.
          if speculative_model is None or num_speculative_tokens is None:
              raise ValueError(
                  "Expected both speculative_model and "
                  "num_speculative_tokens to be provided, but found "
                  f"{speculative_model=} and {num_speculative_tokens=}.")

          if (speculative_disable_by_batch_size is not None
                  and speculative_disable_by_batch_size < 2):
              raise ValueError("Expected the batch size threshold of disabling "
                               "speculative decoding is > 1, but got "
                               f"{speculative_disable_by_batch_size=}")

          if enable_chunked_prefill:
              raise ValueError(
                  "Speculative decoding and chunked prefill are "
                  f"currently mutually exclusive ({enable_chunked_prefill=}).")

          if not use_v2_block_manager:
              raise ValueError(
                  "Speculative decoding requires usage of the V2 "
                  "block manager. Enable it with --use-v2-block-manager.")

          # TODO: The user should be able to specify revision/quantization/max
          # model len for the draft model. It is not currently supported.
          draft_revision = None
          draft_code_revision = None
          draft_quantization = None

          if speculative_model == "[ngram]":
              if ngram_prompt_lookup_min is None:
                  ngram_prompt_lookup_min = 1
              if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
                  raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
              if ngram_prompt_lookup_min < 1:
                  raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
              if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
                  raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
                                   f"larger than {ngram_prompt_lookup_max=}")

              # TODO: current we still need extract vocab_size from target
              # model config, in future, we may try refactoring it out, and
              # set draft related config as None here.
              draft_model_config = target_model_config
              draft_parallel_config = target_parallel_config
          else:
              # A real draft model is used, so the ngram window is disabled.
              ngram_prompt_lookup_max = 0
              ngram_prompt_lookup_min = 0
              draft_model_config = ModelConfig(
                  model=speculative_model,
                  tokenizer=target_model_config.tokenizer,
                  tokenizer_mode=target_model_config.tokenizer_mode,
                  trust_remote_code=target_model_config.trust_remote_code,
                  dtype=target_model_config.dtype,
                  seed=target_model_config.seed,
                  revision=draft_revision,
                  code_revision=draft_code_revision,
                  tokenizer_revision=target_model_config.tokenizer_revision,
                  max_model_len=None,
                  quantization=draft_quantization,
                  enforce_eager=target_model_config.enforce_eager,
                  max_seq_len_to_capture=target_model_config.
                  max_seq_len_to_capture,
                  max_logprobs=target_model_config.max_logprobs,
              )

              draft_model_config.max_model_len = (
                  SpeculativeConfig._maybe_override_draft_max_model_len(
                      speculative_max_model_len,
                      draft_model_config.max_model_len,
                      target_model_config.max_model_len,
                  ))

              draft_parallel_config = (
                  SpeculativeConfig.create_draft_parallel_config(
                      target_parallel_config))

          return SpeculativeConfig(draft_model_config, draft_parallel_config,
                                   num_speculative_tokens,
                                   speculative_disable_by_batch_size,
                                   ngram_prompt_lookup_max,
                                   ngram_prompt_lookup_min)

      @staticmethod
      def _maybe_override_draft_max_model_len(
          speculative_max_model_len: Optional[int],
          draft_max_model_len: int,
          target_max_model_len: int,
      ) -> int:
          """Determine the max sequence len for the draft model. This is
          usually the draft_max_model_len, but may be the
          target_max_model_len if it is less than the draft_max_model_len, or
          may be speculative_max_model_len if it is specified.

          This is necessary so that sequences do not exceed the capacity of
          the draft model or the target model.

          speculative_max_model_len is mainly used for testing that sequences
          can skip speculation.

          Raises:
              ValueError: If speculative_max_model_len is given and exceeds
                  either the draft or the target max model len.
          """
          if speculative_max_model_len is not None:
              if speculative_max_model_len > draft_max_model_len:
                  raise ValueError(f"{speculative_max_model_len=} cannot be "
                                   f"larger than {draft_max_model_len=}")

              if speculative_max_model_len > target_max_model_len:
                  raise ValueError(f"{speculative_max_model_len=} cannot be "
                                   f"larger than {target_max_model_len=}")

              return speculative_max_model_len

          return min(
              draft_max_model_len,
              target_max_model_len,
          )

      @staticmethod
      def create_draft_parallel_config(
              target_parallel_config: ParallelConfig) -> ParallelConfig:
          """Create a parallel config for use by the draft worker.

          This is mostly a copy of the target parallel config. In the future
          the draft worker can have a different parallel strategy, e.g. TP=1.
          """
          draft_parallel_config = ParallelConfig(
              pipeline_parallel_size=target_parallel_config.
              pipeline_parallel_size,
              tensor_parallel_size=target_parallel_config.tensor_parallel_size,
              distributed_executor_backend=target_parallel_config.
              distributed_executor_backend,
              max_parallel_loading_workers=target_parallel_config.
              max_parallel_loading_workers,
              disable_custom_all_reduce=target_parallel_config.
              disable_custom_all_reduce,
              tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
              ray_workers_use_nsight=target_parallel_config.
              ray_workers_use_nsight,
              placement_group=target_parallel_config.placement_group,
          )

          return draft_parallel_config

      def __init__(
          self,
          draft_model_config: ModelConfig,
          draft_parallel_config: ParallelConfig,
          num_speculative_tokens: int,
          speculative_disable_by_batch_size: Optional[int],
          ngram_prompt_lookup_max: Optional[int],
          ngram_prompt_lookup_min: Optional[int],
      ):
          """Create a SpeculativeConfig object.

          Args:
              draft_model_config: ModelConfig for the draft model.
              draft_parallel_config: ParallelConfig for the draft model.
              num_speculative_tokens: The number of tokens to sample from the
                  draft model before scoring with the target model.
              speculative_disable_by_batch_size: Disable speculative
                  decoding for new incoming requests when the number of
                  enqueue requests is larger than this value.
              ngram_prompt_lookup_max: Max size of ngram token window. None
                  is stored as 0.
              ngram_prompt_lookup_min: Min size of ngram token window. None
                  is stored as 0.

          Raises:
              ValueError: If num_speculative_tokens is not greater than zero.
          """
          self.draft_model_config = draft_model_config
          self.draft_parallel_config = draft_parallel_config
          self.num_speculative_tokens = num_speculative_tokens
          self.speculative_disable_by_batch_size = \
              speculative_disable_by_batch_size
          self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
          self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0

          self._verify_args()

      def _verify_args(self) -> None:
          """Validate the token count and check the draft model config
          against the draft parallel config.

          Raises:
              ValueError: If num_speculative_tokens is not greater than zero.
          """
          if self.num_speculative_tokens <= 0:
              raise ValueError("Expected num_speculative_tokens to be greater "
                               f"than zero ({self.num_speculative_tokens}).")

          if self.draft_model_config:
              self.draft_model_config.verify_with_parallel_config(
                  self.draft_parallel_config)

      @property
      def num_lookahead_slots(self) -> int:
          """The number of additional slots the scheduler should allocate per
          step, in addition to the slots allocated for each known token.

          This is equal to the number of speculative tokens, as each
          speculative token must be scored.
          """
          return self.num_speculative_tokens

      def __repr__(self) -> str:
          """Return a short description naming the draft model (or
          "[ngram]" when an ngram window is configured) and the number of
          speculative tokens."""
          if self.ngram_prompt_lookup_max > 0:
              draft_model = "[ngram]"
          else:
              draft_model = self.draft_model_config.model
          num_spec_tokens = self.num_speculative_tokens
          return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
  @dataclass
  class LoRAConfig:
      """LoRA-related settings, validated when the instance is created."""

      max_lora_rank: int
      max_loras: int
      fully_sharded_loras: bool = False
      # Defaults to max_loras when left as None (see __post_init__).
      max_cpu_loras: Optional[int] = None
      # May also arrive as a string; verify_with_model_config resolves it.
      lora_dtype: Optional[torch.dtype] = None
      lora_extra_vocab_size: int = 256
      # This is a constant.
      lora_vocab_padding_size: ClassVar[int] = 256
      long_lora_scaling_factors: Optional[Tuple[float]] = None

      def __post_init__(self):
          """Validate the numeric limits and default ``max_cpu_loras``.

          Raises:
              ValueError: If the rank or extra vocab size is not one of the
                  accepted values, if ``max_loras`` is below 1, or if
                  ``max_cpu_loras`` is below ``max_loras``.
          """
          # Keep this in sync with kernels/punica/bgmv/bgmv_config.h
          allowed_ranks = (8, 16, 32, 64)
          allowed_extra_vocab_sizes = (0, 256, 512)

          if self.max_lora_rank not in allowed_ranks:
              raise ValueError(
                  f"max_lora_rank ({self.max_lora_rank}) must be one of "
                  f"{allowed_ranks}.")

          if self.lora_extra_vocab_size not in allowed_extra_vocab_sizes:
              raise ValueError(
                  f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                  f"must be one of {allowed_extra_vocab_sizes}.")

          if self.max_loras < 1:
              raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")

          cpu_loras = self.max_cpu_loras
          if cpu_loras is None:
              self.max_cpu_loras = self.max_loras
              return
          if cpu_loras < self.max_loras:
              raise ValueError(
                  f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                  f"max_loras ({self.max_loras})")

      def verify_with_model_config(self, model_config: ModelConfig):
          """Resolve ``lora_dtype`` against the model config and warn about
          quantization methods other than "awq" and "gptq".

          A ``lora_dtype`` of None or "auto" takes the model's dtype; any
          other string is looked up as an attribute of ``torch``.
          """
          if self.lora_dtype in (None, "auto"):
              self.lora_dtype = model_config.dtype
          elif isinstance(self.lora_dtype, str):
              self.lora_dtype = getattr(torch, self.lora_dtype)

          quantization = model_config.quantization
          if quantization and quantization not in ["awq", "gptq"]:
              # TODO support all other quants
              logger.warning(f"{model_config.quantization} quantization is not "
                             "tested with LoRA yet.")

      def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
          """Reject scheduler token budgets above 65528.

          Raises:
              ValueError: If ``scheduler_config.max_num_batched_tokens`` is
                  greater than 65528.
          """
          token_budget = scheduler_config.max_num_batched_tokens
          if token_budget > 65528:
              raise ValueError(
                  "Due to limitations of the custom LoRA CUDA kernel, "
                  "max_num_batched_tokens must be <= 65528 when "
                  "LoRA is enabled.")
  @dataclass
  class VisionLanguageConfig:
      """Configs the input data format and how models should run for
      vision language models."""

      class ImageInputType(enum.Enum):
          """Image input type into the vision language model.

          An image roughly goes through the following transformation:
          Raw image --> pixel values --> image features --> image embeddings.

          The difference between different image input types is where the
          image encoder (pixel values --> image features) is run.
          Different image input types also correspond to different tensor
          shapes.

          For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
          IMAGE_FEATURES: (1, 576, 1024).
          """
          PIXEL_VALUES = enum.auto()
          IMAGE_FEATURES = enum.auto()

      image_input_type: ImageInputType
      # The input id corresponding to image token.
      image_token_id: int
      # Used for running `run_prefill_max_token`.
      # For models that support varying resolution, this corresponds to
      # worst case scenario (biggest supported resolution).
      image_input_shape: tuple
      image_feature_size: int

      @classmethod
      def get_image_input_enum_type(
              cls, value: str) -> "VisionLanguageConfig.ImageInputType":
          """Look up an ``ImageInputType`` member by name, ignoring case.

          Raises:
              ValueError: If the upper-cased ``value`` is not a member name.
          """
          member_name = value.upper()
          try:
              member = cls.ImageInputType[member_name]
          except KeyError as e:
              valid_names = [x.name for x in cls.ImageInputType]
              raise ValueError(f"{value} is not a valid choice. "
                               f"Expecting to choose from "
                               f"{valid_names}.") from e
          return member
  # Maps every accepted dtype name to its torch dtype. Each torch dtype is
  # listed once together with all of its aliases; the resulting key order is
  # half, float16, float, float32, bfloat16.
  _STR_DTYPE_TO_TORCH_DTYPE = {
      alias: torch_dtype
      for torch_dtype, aliases in (
          (torch.float16, ("half", "float16")),
          (torch.float32, ("float", "float32")),
          (torch.bfloat16, ("bfloat16", )),
      )
      for alias in aliases
  }

  # dtype names that are left out of the "supported dtypes" list reported
  # for ROCm.
  _ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
  def _get_and_verify_dtype(
      config: PretrainedConfig,
      dtype: Union[str, torch.dtype],
  ) -> torch.dtype:
      """Resolve the requested dtype to a torch dtype and check it against
      the dtype recorded in the model config.

      Args:
          config: Model config; its ``torch_dtype`` attribute is read and
              treated as float32 when missing or None.
          dtype: "auto", one of the names in ``_STR_DTYPE_TO_TORCH_DTYPE``
              (case-insensitive), or a ``torch.dtype``.

      Returns:
          The resolved torch dtype. With "auto", a float32 config resolves
          to float16 and any other config dtype is used as is.

      Raises:
          ValueError: If ``dtype`` is an unknown name or an unsupported
              type, or if the result is float32 while ``is_hip()`` is true.
      """
      # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
      # because config.torch_dtype can be None.
      config_dtype = getattr(config, "torch_dtype", None)
      if config_dtype is None:
          config_dtype = torch.float32

      if isinstance(dtype, torch.dtype):
          torch_dtype = dtype
      elif isinstance(dtype, str):
          dtype = dtype.lower()
          if dtype == "auto":
              # Following the common practice, we use float16 for float32
              # models.
              torch_dtype = (torch.float16 if config_dtype == torch.float32
                             else config_dtype)
          elif dtype in _STR_DTYPE_TO_TORCH_DTYPE:
              torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
          else:
              raise ValueError(f"Unknown dtype: {dtype}")
      else:
          raise ValueError(f"Unknown dtype: {dtype}")

      if is_hip() and torch_dtype == torch.float32:
          rocm_supported_dtypes = [
              name for name in _STR_DTYPE_TO_TORCH_DTYPE
              if name not in _ROCM_NOT_SUPPORTED_DTYPE
          ]
          raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                           f"Supported dtypes are {rocm_supported_dtypes}")

      # Verify the dtype. Upcasting to float32 and downcasting from float32
      # are both allowed silently; only a cast between float16 and bfloat16
      # is reported with a warning.
      involves_float32 = (torch_dtype == torch.float32
                          or config_dtype == torch.float32)
      if torch_dtype != config_dtype and not involves_float32:
          logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

      return torch_dtype
  def _get_and_verify_max_len(
      hf_config: PretrainedConfig,
      max_model_len: Optional[int],
  ) -> int:
      """Get and verify the model's maximum length.

      The derived limit is the smallest value found among the known length
      attributes of ``hf_config``, multiplied by the RoPE scaling factor
      when ``rope_scaling`` is set to a type other than "su", "longrope" or
      "llama3".

      Args:
          hf_config: The model's Hugging Face config. When the requested
              length exceeds the derived limit, its ``rope_scaling``
              attribute is overwritten with a "dynamic" scaling entry.
          max_model_len: The user-requested maximum length, or None to use
              the derived limit.

      Returns:
          ``max_model_len`` if given, otherwise the derived limit. If no
          length attribute is present and no length was requested, 2048 is
          assumed (before any RoPE scaling is applied).
      """
      possible_keys = [
          # Cohere / Command-R
          "model_max_length",
          # OPT
          "max_position_embeddings",
          # GPT-2
          "n_positions",
          # MPT
          "max_seq_len",
          # ChatGLM2
          "seq_length",
          # Others
          "max_sequence_length",
          "max_seq_length",
          "seq_len",
      ]

      # The smallest value among the attributes that are present wins.
      derived_max_model_len = float("inf")
      for key in possible_keys:
          max_len = getattr(hf_config, key, None)
          if max_len is not None:
              derived_max_model_len = min(derived_max_model_len, max_len)

      if derived_max_model_len == float("inf"):
          if max_model_len is not None:
              # If max_model_len is specified, we use it.
              return max_model_len

          default_max_len = 2048
          logger.warning(
              "The model's config.json does not contain any of the following "
              "keys to determine the original maximum length of the model: "
              f"{possible_keys}. Assuming the model's maximum length is "
              f"{default_max_len}.")
          derived_max_model_len = default_max_len

      rope_scaling = getattr(hf_config, "rope_scaling", None)
      if rope_scaling is not None:
          # The scaling type may be stored under either "type" or
          # "rope_type".
          rope_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
          if rope_type not in {"su", "longrope", "llama3"}:
              assert "factor" in rope_scaling
              scaling_factor = rope_scaling["factor"]
              if rope_type == "yarn":
                  # For yarn the factor applies to the original length.
                  derived_max_model_len = rope_scaling[
                      "original_max_position_embeddings"]
              derived_max_model_len *= scaling_factor

      if max_model_len is None:
          max_model_len = derived_max_model_len
      elif max_model_len > derived_max_model_len:
          # Rather than rejecting the request, stretch the context with
          # dynamic RoPE scaling; this replaces any existing rope_scaling.
          scaling_factor = max_model_len / derived_max_model_len
          hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
          logger.warning(
              f"User-specified max_model_len {max_model_len} is higher than "
              f"the original {derived_max_model_len}. "
              "Attempting to use RoPE scaling.")

      return int(max_model_len)
  @dataclass
  class DecodingConfig:
      """Dataclass which contains the decoding strategy of the engine."""

      # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
      guided_decoding_backend: str = 'outlines'

      def __post_init__(self):
          """Reject guided decoding backends that are not supported.

          Raises:
              ValueError: If ``guided_decoding_backend`` is not 'outlines'
                  or 'lm-format-enforcer'.
          """
          valid_guided_backends = ['outlines', 'lm-format-enforcer']
          backend = self.guided_decoding_backend
          if backend not in valid_guided_backends:
              raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                               f"must be one of {valid_guided_backends}")
  @dataclass(frozen=True)
  class EngineConfig:
      """Immutable bundle of every engine-related configuration object, so
      the individual configs can be passed around the codebase together.
      """

      model_config: ModelConfig
      cache_config: CacheConfig
      parallel_config: ParallelConfig
      scheduler_config: SchedulerConfig
      device_config: DeviceConfig
      load_config: LoadConfig
      lora_config: Optional[LoRAConfig]
      vision_language_config: Optional[VisionLanguageConfig]
      speculative_config: Optional[SpeculativeConfig]
      decoding_config: Optional[DecodingConfig]

      def __post_init__(self):
          """Cross-check the configs against each other.

          The model and cache configs are verified against the parallel
          config; when a LoRA config is present it is verified against the
          model and scheduler configs.
          """
          parallel = self.parallel_config
          self.model_config.verify_with_parallel_config(parallel)
          self.cache_config.verify_with_parallel_config(parallel)

          lora = self.lora_config
          if lora:
              lora.verify_with_model_config(self.model_config)
              lora.verify_with_scheduler_config(self.scheduler_config)

      def to_dict(self):
          """Return the configs as a dictionary keyed by field name, for use
          in **kwargs.
          """
          return {
              config_field.name: getattr(self, config_field.name)
              for config_field in fields(self)
          }