args_tools.py

import argparse
import dataclasses
import json
import os
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
                    Union)

from loguru import logger

from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
                                     EngineConfig, LoadConfig, LoRAConfig,
                                     ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig,
                                     SpeculativeConfig, TokenizerPoolConfig)
from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
from aphrodite.executor.executor_base import ExecutorBase
from aphrodite.quantization import QUANTIZATION_METHODS
from aphrodite.transformers_utils.utils import check_gguf_file

if TYPE_CHECKING:
    from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup

APHRODITE_USE_RAY_SPMD_WORKER = bool(
    os.getenv("APHRODITE_USE_RAY_SPMD_WORKER", 0))


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    if len(val) == 0:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        try:
            key, value = item.split("=")
        except ValueError as exc:
            # Unpacking a split that does not yield exactly two parts raises
            # ValueError, not TypeError.
            msg = "Each item should be in the form KEY=VALUE"
            raise ValueError(msg) from exc

        try:
            out_dict[key] = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise ValueError(msg) from exc

    return out_dict
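

# Usage note (illustrative, not part of the original module): nullable_kvs
# parses comma-separated KEY=VALUE pairs into an integer mapping and backs the
# --limit-mm-per-prompt flag defined below, e.g.
#   nullable_kvs("image=16,video=2")  ->  {"image": 16, "video": 2}
#   nullable_kvs("")                  ->  None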


@dataclass
class EngineArgs:
    """Arguments for Aphrodite engine."""
    # Model Options
    model: str
    seed: int = 0
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    max_model_len: Optional[int] = None
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    model_loader_extra_config: Optional[dict] = None
    enforce_eager: Optional[bool] = None
    skip_tokenizer_init: bool = False
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
    max_logprobs: int = 10  # OpenAI default is 5, setting to 10 because ST
    # Device Options
    device: str = "auto"
    # Load Options
    load_format: str = "auto"
    dtype: str = "auto"
    ignore_patterns: Optional[Union[str, List[str]]] = None
    # Parallel Options
    worker_use_ray: Optional[bool] = False
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    ray_workers_use_nsight: bool = False
    disable_custom_all_reduce: bool = False
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
    max_parallel_loading_workers: Optional[int] = None
    # Quantization Options
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    preemption_mode: Optional[str] = None
    deepspeed_fp_bits: Optional[int] = None
    # Cache Options
    kv_cache_dtype: str = "auto"
    block_size: int = 16
    enable_prefix_caching: Optional[bool] = False
    num_gpu_blocks_override: Optional[int] = None
    disable_sliding_window: bool = False
    gpu_memory_utilization: float = 0.90
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
    # Scheduler Options
    use_v2_block_manager: bool = False
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False
    guided_decoding_backend: str = 'outlines'
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    num_scheduler_steps: int = 1
    # Speculative Decoding Options
    num_lookahead_slots: int = 0
    speculative_model: Optional[str] = None
    speculative_model_quantization: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None
    # Adapter Options
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    lora_dtype: str = "auto"
    max_cpu_loras: Optional[int] = None
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    fully_sharded_loras: bool = False
    qlora_adapter_name_or_path: Optional[str] = None
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    # Log Options
    disable_log_stats: bool = False

    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model
        if is_cpu():
            self.distributed_executor_backend = None
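
    # Example (illustrative, not part of the original module): EngineArgs can
    # also be constructed directly in Python with the same defaults the CLI
    # exposes, e.g.
    #   args = EngineArgs(model="EleutherAI/pythia-70m-deduped", dtype="auto")
    # __post_init__ then reuses the model path as the tokenizer path whenever
    # no explicit tokenizer is given.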

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for the Aphrodite engine."""

        # Model Options
        parser.add_argument(
            "--model",
            type=str,
            default="EleutherAI/pythia-70m-deduped",
            help="Category: Model Options\n"
            "name or path of the huggingface model to use",
        )
        parser.add_argument("--seed",
                            type=int,
                            default=EngineArgs.seed,
                            help="Category: Model Options\n"
                            "random seed")
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="Category: API Options\n"
            "The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in the `model_name` tag content of "
            "Prometheus metrics; if multiple names are provided, the "
            "metrics tag will take the first one.")
        parser.add_argument(
            "--tokenizer",
            type=str,
            default=EngineArgs.tokenizer,
            help="Category: Model Options\n"
            "name or path of the huggingface tokenizer to use",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--code-revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific revision to use for the model code on "
            "Hugging Face Hub. It can be a branch name, a tag name, or a "
            "commit id. If unspecified, will use the default version.",
        )
        parser.add_argument(
            "--tokenizer-revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific tokenizer version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow', 'mistral'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.\n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Category: Model Options\n"
            "trust remote code from huggingface",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=EngineArgs.download_dir,
            help="Category: Model Options\n"
            "directory to download and load the weights, "
            "default to the default cache dir of "
            "huggingface",
        )
        parser.add_argument(
            "--max-model-len",
            type=int,
            default=EngineArgs.max_model_len,
            help="Category: Model Options\n"
            "model context length. If unspecified, "
            "will be automatically derived from the model.",
        )
        parser.add_argument("--max-context-len-to-capture",
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help="Category: Model Options\n"
                            "Maximum context length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode. "
                            "(DEPRECATED. Use --max-seq_len-to-capture instead"
                            ")")
        parser.add_argument("--max-seq_len-to-capture",
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help="Category: Model Options\n"
                            "Maximum sequence length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode.")
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='Category: Model Options\n'
                            'RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='Category: Model Options\n'
                            'RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument("--model-loader-extra-config",
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help="Category: Model Options\n"
                            "Extra config for model loader. "
                            "This will be passed to the model loader "
                            "corresponding to the chosen load_format. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary.")
        parser.add_argument(
            "--enforce-eager",
            action=StoreBoolean,
            default=EngineArgs.enforce_eager,
            nargs="?",
            const="True",
            help="Category: Model Options\n"
            "Always use eager-mode PyTorch. If False, "
            "will use eager mode and CUDA graph in hybrid "
            "for maximal performance and flexibility.",
        )
        parser.add_argument("--skip-tokenizer-init",
                            action="store_true",
                            help="Category: Model Options\n"
                            "Skip initialization of tokenizer and detokenizer")
        parser.add_argument("--tokenizer-pool-size",
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help="Category: Model Options\n"
                            "Size of tokenizer pool to use for "
                            "asynchronous tokenization. If 0, will "
                            "use synchronous tokenization.")
        parser.add_argument("--tokenizer-pool-type",
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help="Category: Model Options\n"
                            "The type of tokenizer pool to use for "
                            "asynchronous tokenization. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument("--tokenizer-pool-extra-config",
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help="Category: Model Options\n"
                            "Extra config for tokenizer pool. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary. Ignored if "
                            "tokenizer_pool_size is 0.")

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
        parser.add_argument(
            "--max-logprobs",
            type=int,
            default=EngineArgs.max_logprobs,
            help="Category: Model Options\n"
            "maximum number of log probabilities to "
            "return.",
        )

        # Device Options
        parser.add_argument(
            "--device",
            type=str,
            default=EngineArgs.device,
            choices=[
                "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu"
            ],
            help=("Category: Model Options\n"
                  "Device to use for model execution."),
        )

        # Load Options
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto',
                'pt',
                'safetensors',
                'npcache',
                'dummy',
                'tensorizer',
                'sharded_state',
                'bitsandbytes',
            ],
            help='Category: Model Options\n'
            'The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize Aphrodite Model script in the '
            'Examples section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Category: Model Options\n'
            'Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="Category: Model Options\n"
            "The pattern(s) to ignore when loading the model. "
            "Defaults to 'original/**/*' to avoid repeated loading of llama's "
            "checkpoints.")

        # Parallel Options
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Category: Parallel Options\n'
            'Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument(
            "--tensor-parallel-size",
            "-tp",
            type=int,
            default=EngineArgs.tensor_parallel_size,
            help="Category: Parallel Options\n"
            "number of tensor parallel replicas, i.e. the number of GPUs "
            "to use.")
        parser.add_argument(
            "--pipeline-parallel-size",
            "-pp",
            type=int,
            default=EngineArgs.pipeline_parallel_size,
            help="Category: Parallel Options\n"
            "number of pipeline stages. Currently not supported.")
        parser.add_argument(
            "--ray-workers-use-nsight",
            action="store_true",
            help="Category: Parallel Options\n"
            "If specified, use nsight to profile ray workers",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=EngineArgs.disable_custom_all_reduce,
            help="Category: Model Options\n"
            "See ParallelConfig",
        )
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Category: Parallel Options\n'
            'Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            "--max-parallel-loading-workers",
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help="Category: Parallel Options\n"
            "load model sequentially in multiple batches, "
            "to avoid RAM OOM when using tensor "
            "parallel and large models",
        )

        # Quantization Options
        parser.add_argument(
            "--quantization",
            "-q",
            type=str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.quantization,
            help="Category: Quantization Options\n"
            "Method used to quantize the weights. If "
            "None, we first check the `quantization_config` "
            "attribute in the model config file. If that is "
            "None, we assume the model weights are not "
            "quantized and use `dtype` to determine the data "
            "type of the weights.",
        )
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Category: Quantization Options\n'
            'Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied when '
            'the KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on CUDA versions '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=None,
            help='Category: Scheduler Options\n'
            'If \'recompute\', the engine performs preemption by '
            'recomputing; if \'swap\', the engine performs preemption by '
            'block swapping.')
        parser.add_argument("--deepspeed-fp-bits",
                            type=int,
                            default=None,
                            help="Category: Quantization Options\n"
                            "Number of floating bits to use for the deepspeed "
                            "quantization. Supported bits are: 4, 6, 8, 12.")

        # Cache Options
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Category: Cache Options\n'
            'Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            "--block-size",
            type=int,
            default=EngineArgs.block_size,
            choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
            help="Category: Cache Options\n"
            "token block size",
        )
        parser.add_argument(
            "--enable-prefix-caching",
            "--context-shift",
            action="store_true",
            help="Category: Cache Options\n"
            "Enable automatic prefix caching.",
        )
        parser.add_argument(
            "--num-gpu-blocks-override",
            type=int,
            default=None,
            help="Category: Cache Options\n"
            "If specified, ignore GPU profiling result and use this "
            "number of GPU blocks. Used for testing preemption.")
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Category: KV Cache Options\n'
                            'Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument(
            "--gpu-memory-utilization",
            "-gmu",
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="Category: Cache Options\n"
            "The fraction of GPU memory to be used for "
            "the model executor, which can range from 0 to 1. "
            "If unspecified, will use the default value of 0.9.",
        )
        parser.add_argument(
            "--swap-space",
            type=float,
            default=EngineArgs.swap_space,
            help="Category: Cache Options\n"
            "CPU swap space size (GiB) per GPU",
        )
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='Category: Cache Options\n'
            'The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weights, '
            'which requires at least 26 GB of GPU memory. Note that this '
            'requires a fast CPU-GPU interconnect, as part of the model is '
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')

        # Scheduler Options
        parser.add_argument("--use-v2-block-manager",
                            action="store_true",
                            help="Category: Scheduler Options\n"
                            "Use the v2 block manager.")
        parser.add_argument(
            "--scheduler-delay-factor",
            "-sdf",
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help="Category: Scheduler Options\n"
            "Apply a delay (of delay factor multiplied by previous "
            "prompt latency) before scheduling next prompt.")
        parser.add_argument(
            "--enable-chunked-prefill",
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
            help="Category: Scheduler Options\n"
            "If True, the prefill requests can be chunked based on the "
            "max_num_batched_tokens.")
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Category: Scheduler Options\n'
            'Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently supports '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via the guided_decoding_backend'
            ' parameter.')
        parser.add_argument(
            "--max-num-batched-tokens",
            type=int,
            default=EngineArgs.max_num_batched_tokens,
            help="Category: KV Cache Options\n"
            "maximum number of batched tokens per "
            "iteration",
        )
        parser.add_argument(
            "--max-num-seqs",
            type=int,
            default=EngineArgs.max_num_seqs,
            help="Category: API Options\n"
            "maximum number of sequences per iteration",
        )
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))

        # Speculative Decoding Options
        parser.add_argument("--num-lookahead-slots",
                            type=int,
                            default=EngineArgs.num_lookahead_slots,
                            help="Category: Speculative Decoding Options\n"
                            "Experimental scheduling config necessary for "
                            "speculative decoding. This will be replaced by "
                            "speculative decoding config in the future; it is "
                            "present for testing purposes until then.")
        parser.add_argument(
            "--speculative-model",
            type=str,
            default=EngineArgs.speculative_model,
            help="Category: Speculative Decoding Options\n"
            "The name of the draft model to be used in speculative decoding.")
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
            help='Method used to quantize the weights of the speculative '
            'model. If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
        parser.add_argument("--num-speculative-tokens",
                            type=int,
                            default=EngineArgs.num_speculative_tokens,
                            help="Category: Speculative Decoding Options\n"
                            "The number of speculative tokens to sample from "
                            "the draft model in speculative decoding")
        parser.add_argument(
            "--speculative-max-model-len",
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help="Category: Speculative Decoding Options\n"
            "The maximum sequence length supported by the "
            "draft model. Sequences over this length will skip "
            "speculation.")
        parser.add_argument(
            "--ngram-prompt-lookup-max",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help="Category: Speculative Decoding Options\n"
            "Max size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--ngram-prompt-lookup-min",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help="Category: Speculative Decoding Options\n"
            "Min size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--speculative-draft-tensor-parallel-size",
            "-spec-draft-tp",
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help="Category: Speculative Decoding Options\n"
            "Number of tensor parallel replicas for "
            "the draft model in speculative decoding.")
        parser.add_argument(
            "--speculative-disable-by-batch-size",
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help="Category: Speculative Decoding Options\n"
            "Disable speculative decoding for new incoming requests "
            "if the number of enqueued requests is larger than this value.")
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Category: Speculative Decoding Options\n'
            'Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Category: Speculative Decoding Options\n'
            'Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='Category: Speculative Decoding Options\n'
            'A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold, '
            'i.e. 0.3')
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
            type=bool,
            default=EngineArgs.disable_logprobs_during_spec_decoding,
            help='Category: Speculative Decoding Options\n'
            'If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')

        # Adapter Options
        parser.add_argument(
            "--enable-lora",
            action="store_true",
            help="Category: Adapter Options\n"
            "If True, enable handling of LoRA adapters.",
        )
        parser.add_argument(
            "--max-loras",
            type=int,
            default=EngineArgs.max_loras,
            help="Category: Adapter Options\n"
            "Max number of LoRAs in a single batch.",
        )
        parser.add_argument(
            "--max-lora-rank",
            type=int,
            default=EngineArgs.max_lora_rank,
            help="Category: Adapter Options\n"
            "Max LoRA rank.",
        )
        parser.add_argument(
            "--lora-extra-vocab-size",
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=("Category: Adapter Options\n"
                  "Maximum size of extra vocabulary that can be "
                  "present in a LoRA adapter (added to the base "
                  "model vocabulary)."),
        )
        parser.add_argument(
            "--lora-dtype",
            type=str,
            default=EngineArgs.lora_dtype,
            choices=["auto", "float16", "bfloat16", "float32"],
            help=("Category: Adapter Options\n"
                  "Data type for LoRA. If auto, will default to "
                  "base model dtype."),
        )
        parser.add_argument(
            "--max-cpu-loras",
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=("Category: Adapter Options\n"
                  "Maximum number of LoRAs to store in CPU memory. "
                  "Must be >= max_num_seqs. "
                  "Defaults to max_num_seqs."),
        )
        parser.add_argument(
            "--long-lora-scaling-factors",
            type=str,
            default=EngineArgs.long_lora_scaling_factors,
            help=("Category: Adapter Options\n"
                  "Specify multiple scaling factors (which can "
                  "be different from the base model scaling factor "
                  "- see e.g. Long LoRA) to allow for multiple "
                  "LoRA adapters trained with those scaling "
                  "factors to be used at the same time. If not "
                  "specified, only adapters trained with the "
                  "base model scaling factor are allowed."))
        parser.add_argument(
            "--fully-sharded-loras",
            action='store_true',
            help=("Category: Adapter Options\n"
                  "By default, only half of the LoRA computation is sharded "
                  "with tensor parallelism. Enabling this will use the fully "
                  "sharded layers. At high sequence length, max rank or "
                  "tensor parallel size, this is likely faster."))
        parser.add_argument("--qlora-adapter-name-or-path",
                            type=str,
                            default=None,
                            help="Category: Adapter Options\n"
                            "Name or path of the QLoRA adapter to use.")
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='Category: Adapter Options\n'
                            'If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Category: Adapter Options\n'
                            'Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Category: Adapter Options\n'
                            'Max number of PromptAdapters tokens')

        # Log Options
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Category: Log Options\n"
            "disable logging statistics",
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")
        if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            deepspeed_fp_bits=self.deepspeed_fp_bits,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
        )
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free(),
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                tokenizer_pool_size=self.tokenizer_pool_size,
                tokenizer_pool_type=self.tokenizer_pool_type,
                tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
            if use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                has_seqlen_agnostic_layers = (
                    model_config.contains_seqlen_agnostic_layers(
                        parallel_config))
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and not self.enable_prefix_caching
                        and not has_seqlen_agnostic_layers):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Currently, chunked "
                        "prefill might not work with some features or "
                        "models. If you encounter any issues, please "
                        "disable chunked prefill by setting "
                        "--enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                f"The model has a long context length ({max_model_len}). "
                "This may cause OOM errors during the initial memory "
                "profiling phase, or result in low performance due to small "
                "KV cache space. Consider setting --max-model-len to a "
                "smaller value.")

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization=\
                self.speculative_model_quantization,
            speculative_draft_tensor_parallel_size=self.
            speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        if self.num_scheduler_steps > 1:
            raise NotImplementedError("Multi-step is not yet supported.")
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill:
                raise ValueError("Chunked prefill is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")

        # make sure num_lookahead_slots is set to the higher value depending
        # on whether we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            is_attention_free=model_config.is_attention_free(),
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
            send_delta_data=(APHRODITE_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
                self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns)

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
            if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config,
                            prompt_adapter_config=prompt_adapter_config)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous Aphrodite engine."""
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    uvloop: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        parser.add_argument(
            "--uvloop",
            action="store_true",
            help="Use the Uvloop asyncio event loop to possibly increase "
            "performance")
        return parser


class StoreBoolean(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        if values.lower() == "true":
            setattr(namespace, self.dest, True)
        elif values.lower() == "false":
            setattr(namespace, self.dest, False)
        else:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
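

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It only wires
# together names defined or imported above; the sample flags are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a parser with the shared engine flags, parse an example command
    # line, and turn the resulting namespace back into an EngineArgs instance.
    example_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    example_args = example_parser.parse_args(
        ["--model", "EleutherAI/pythia-70m-deduped", "--max-model-len", "2048"])
    engine_args = EngineArgs.from_cli_args(example_args)
    # create_engine_config() would go on to fetch and validate the model
    # config, so this sketch stops at printing the parsed arguments.
    print(engine_args)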