args_tools.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124
  1. import argparse
  2. import dataclasses
  3. import json
  4. from dataclasses import dataclass
  5. from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
  6. Union)
  7. from loguru import logger
  8. from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
  9. EngineConfig, LoadConfig, LoRAConfig,
  10. ModelConfig, ParallelConfig,
  11. PromptAdapterConfig, SchedulerConfig,
  12. SpeculativeConfig, TokenizerPoolConfig)
  13. from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
  14. from aphrodite.executor.executor_base import ExecutorBase
  15. from aphrodite.quantization import QUANTIZATION_METHODS
  16. from aphrodite.transformers_utils.utils import check_gguf_file
  17. if TYPE_CHECKING:
  18. from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup
  19. def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
  20. if len(val) == 0:
  21. return None
  22. out_dict: Dict[str, int] = {}
  23. for item in val.split(","):
  24. try:
  25. key, value = item.split("=")
  26. except TypeError as exc:
  27. msg = "Each item should be in the form KEY=VALUE"
  28. raise ValueError(msg) from exc
  29. try:
  30. out_dict[key] = int(value)
  31. except ValueError as exc:
  32. msg = f"Failed to parse value of item {key}={value}"
  33. raise ValueError(msg) from exc
  34. return out_dict
  35. @dataclass
  36. class EngineArgs:
  37. """Arguments for Aphrodite engine."""
  38. # Model Options
  39. model: str
  40. seed: int = 0
  41. served_model_name: Optional[Union[str, List[str]]] = None
  42. tokenizer: Optional[str] = None
  43. revision: Optional[str] = None
  44. code_revision: Optional[str] = None
  45. tokenizer_revision: Optional[str] = None
  46. tokenizer_mode: str = "auto"
  47. trust_remote_code: bool = False
  48. download_dir: Optional[str] = None
  49. max_model_len: Optional[int] = None
  50. max_context_len_to_capture: Optional[int] = None
  51. max_seq_len_to_capture: int = 8192
  52. rope_scaling: Optional[dict] = None
  53. rope_theta: Optional[float] = None
  54. model_loader_extra_config: Optional[dict] = None
  55. enforce_eager: Optional[bool] = None
  56. skip_tokenizer_init: bool = False
  57. tokenizer_pool_size: int = 0
  58. # Note: Specifying a tokenizer pool by passing a class
  59. # is intended for expert use only. The API may change without
  60. # notice.
  61. tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
  62. tokenizer_pool_extra_config: Optional[dict] = None
  63. limit_mm_per_prompt: Optional[Mapping[str, int]] = None
  64. max_logprobs: int = 10 # OpenAI default is 5, setting to 10 because ST
  65. # Device Options
  66. device: str = "auto"
  67. # Load Options
  68. load_format: str = "auto"
  69. dtype: str = "auto"
  70. ignore_patterns: Optional[Union[str, List[str]]] = None
  71. # Parallel Options
  72. worker_use_ray: Optional[bool] = False
  73. tensor_parallel_size: int = 1
  74. pipeline_parallel_size: int = 1
  75. ray_workers_use_nsight: bool = False
  76. disable_custom_all_reduce: bool = False
  77. # Note: Specifying a custom executor backend by passing a class
  78. # is intended for expert use only. The API may change without
  79. # notice.
  80. distributed_executor_backend: Optional[Union[str,
  81. Type[ExecutorBase]]] = None
  82. max_parallel_loading_workers: Optional[int] = None
  83. # Quantization Options
  84. quantization: Optional[str] = None
  85. quantization_param_path: Optional[str] = None
  86. preemption_mode: Optional[str] = None
  87. deepspeed_fp_bits: Optional[int] = None
  88. quant_llm_fp_bits: Optional[int] = None
  89. quant_llm_exp_bits: Optional[int] = None
  90. # Cache Options
  91. kv_cache_dtype: str = "auto"
  92. block_size: int = 16
  93. enable_prefix_caching: Optional[bool] = False
  94. num_gpu_blocks_override: Optional[int] = None
  95. disable_sliding_window: bool = False
  96. gpu_memory_utilization: float = 0.90
  97. swap_space: float = 4 # GiB
  98. cpu_offload_gb: float = 0 # GiB
  99. # Scheduler Options
  100. use_v2_block_manager: bool = False
  101. scheduler_delay_factor: float = 0.0
  102. enable_chunked_prefill: bool = False
  103. guided_decoding_backend: str = 'outlines'
  104. max_num_batched_tokens: Optional[int] = None
  105. max_num_seqs: int = 256
  106. num_scheduler_steps: int = 1
  107. # Speculative Decoding Options
  108. num_lookahead_slots: int = 0
  109. speculative_model: Optional[str] = None
  110. speculative_model_quantization: Optional[str] = None
  111. num_speculative_tokens: Optional[int] = None
  112. speculative_max_model_len: Optional[int] = None
  113. ngram_prompt_lookup_max: Optional[int] = None
  114. ngram_prompt_lookup_min: Optional[int] = None
  115. speculative_draft_tensor_parallel_size: Optional[int] = None
  116. speculative_disable_by_batch_size: Optional[int] = None
  117. spec_decoding_acceptance_method: str = 'rejection_sampler'
  118. typical_acceptance_sampler_posterior_threshold: Optional[float] = None
  119. typical_acceptance_sampler_posterior_alpha: Optional[float] = None
  120. disable_logprobs_during_spec_decoding: Optional[bool] = None
  121. # Adapter Options
  122. enable_lora: bool = False
  123. max_loras: int = 1
  124. max_lora_rank: int = 16
  125. lora_extra_vocab_size: int = 256
  126. lora_dtype: str = "auto"
  127. max_cpu_loras: Optional[int] = None
  128. long_lora_scaling_factors: Optional[Tuple[float]] = None
  129. fully_sharded_loras: bool = False
  130. qlora_adapter_name_or_path: Optional[str] = None
  131. enable_prompt_adapter: bool = False
  132. max_prompt_adapters: int = 1
  133. max_prompt_adapter_token: int = 0
  134. # Log Options
  135. disable_log_stats: bool = False
  136. def __post_init__(self):
  137. if self.tokenizer is None:
  138. self.tokenizer = self.model
  139. if is_cpu():
  140. self.distributed_executor_backend = None
  141. @staticmethod
  142. def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
  143. """Shared CLI arguments for the Aphrodite engine."""
  144. # Model Options
  145. parser.add_argument(
  146. "--model",
  147. type=str,
  148. default="EleutherAI/pythia-70m-deduped",
  149. help="Category: Model Options\n"
  150. "name or path of the huggingface model to use",
  151. )
  152. parser.add_argument("--seed",
  153. type=int,
  154. default=EngineArgs.seed,
  155. help="Category: Model Options\n"
  156. "random seed")
  157. parser.add_argument(
  158. "--served-model-name",
  159. nargs="+",
  160. type=str,
  161. default=None,
  162. help="Category: API Options\n"
  163. "The model name(s) used in the API. If multiple "
  164. "names are provided, the server will respond to any "
  165. "of the provided names. The model name in the model "
  166. "field of a response will be the first name in this "
  167. "list. If not specified, the model name will be the "
  168. "same as the `--model` argument. Noted that this name(s)"
  169. "will also be used in `model_name` tag content of "
  170. "prometheus metrics, if multiple names provided, metrics"
  171. "tag will take the first one.")
  172. parser.add_argument(
  173. "--tokenizer",
  174. type=str,
  175. default=EngineArgs.tokenizer,
  176. help="Category: Model Options\n"
  177. "name or path of the huggingface tokenizer to use",
  178. )
  179. parser.add_argument(
  180. "--revision",
  181. type=str,
  182. default=None,
  183. help="Category: Model Options\n"
  184. "the specific model version to use. It can be a branch "
  185. "name, a tag name, or a commit id. If unspecified, will use "
  186. "the default version.",
  187. )
  188. parser.add_argument(
  189. "--code-revision",
  190. type=str,
  191. default=None,
  192. help="Category: Model Options\n"
  193. "the specific revision to use for the model code on "
  194. "Hugging Face Hub. It can be a branch name, a tag name, or a "
  195. "commit id. If unspecified, will use the default version.",
  196. )
  197. parser.add_argument(
  198. "--tokenizer-revision",
  199. type=str,
  200. default=None,
  201. help="Category: Model Options\n"
  202. "the specific tokenizer version to use. It can be a branch "
  203. "name, a tag name, or a commit id. If unspecified, will use "
  204. "the default version.",
  205. )
  206. parser.add_argument(
  207. "--tokenizer-mode",
  208. type=str,
  209. default=EngineArgs.tokenizer_mode,
  210. choices=['auto', 'slow', 'mistral'],
  211. help='The tokenizer mode.\n\n* "auto" will use the '
  212. 'fast tokenizer if available.\n* "slow" will '
  213. 'always use the slow tokenizer. \n* '
  214. '"mistral" will always use the `mistral_common` tokenizer.')
  215. parser.add_argument(
  216. "--trust-remote-code",
  217. action="store_true",
  218. help="Category: Model Options\n"
  219. "trust remote code from huggingface",
  220. )
  221. parser.add_argument(
  222. "--download-dir",
  223. type=str,
  224. default=EngineArgs.download_dir,
  225. help="Category: Model Options\n"
  226. "directory to download and load the weights, "
  227. "default to the default cache dir of "
  228. "huggingface",
  229. )
  230. parser.add_argument(
  231. "--max-model-len",
  232. type=int,
  233. default=EngineArgs.max_model_len,
  234. help="Category: Model Options\n"
  235. "model context length. If unspecified, "
  236. "will be automatically derived from the model.",
  237. )
  238. parser.add_argument("--max-context-len-to-capture",
  239. type=int,
  240. default=EngineArgs.max_context_len_to_capture,
  241. help="Category: Model Options\n"
  242. "Maximum context length covered by CUDA "
  243. "graphs. When a sequence has context length "
  244. "larger than this, we fall back to eager mode. "
  245. "(DEPRECATED. Use --max-seq_len-to-capture instead"
  246. ")")
  247. parser.add_argument("--max-seq_len-to-capture",
  248. type=int,
  249. default=EngineArgs.max_seq_len_to_capture,
  250. help="Category: Model Options\n"
  251. "Maximum sequence length covered by CUDA "
  252. "graphs. When a sequence has context length "
  253. "larger than this, we fall back to eager mode.")
  254. parser.add_argument('--rope-scaling',
  255. default=None,
  256. type=json.loads,
  257. help='Category: Model Options\n'
  258. 'RoPE scaling configuration in JSON format. '
  259. 'For example, {"type":"dynamic","factor":2.0}')
  260. parser.add_argument('--rope-theta',
  261. default=None,
  262. type=float,
  263. help='Category: Model Options\n'
  264. 'RoPE theta. Use with `rope_scaling`. In '
  265. 'some cases, changing the RoPE theta improves the '
  266. 'performance of the scaled model.')
  267. parser.add_argument("--model-loader-extra-config",
  268. type=str,
  269. default=EngineArgs.model_loader_extra_config,
  270. help="Category: Model Options\n"
  271. "Extra config for model loader. "
  272. "This will be passed to the model loader "
  273. "corresponding to the chosen load_format. "
  274. "This should be a JSON string that will be "
  275. "parsed into a dictionary.")
  276. parser.add_argument(
  277. "--enforce-eager",
  278. action=StoreBoolean,
  279. default=EngineArgs.enforce_eager,
  280. nargs="?",
  281. const="True",
  282. help="Category: Model Options\n"
  283. "Always use eager-mode PyTorch. If False, "
  284. "will use eager mode and CUDA graph in hybrid "
  285. "for maximal performance and flexibility.",
  286. )
  287. parser.add_argument("--skip-tokenizer-init",
  288. action="store_true",
  289. help="Category: Model Options\n"
  290. "Skip initialization of tokenizer and detokenizer")
  291. parser.add_argument("--tokenizer-pool-size",
  292. type=int,
  293. default=EngineArgs.tokenizer_pool_size,
  294. help="Category: Model Options\n"
  295. "Size of tokenizer pool to use for "
  296. "asynchronous tokenization. If 0, will "
  297. "use synchronous tokenization.")
  298. parser.add_argument("--tokenizer-pool-type",
  299. type=str,
  300. default=EngineArgs.tokenizer_pool_type,
  301. help="Category: Model Options\n"
  302. "The type of tokenizer pool to use for "
  303. "asynchronous tokenization. Ignored if "
  304. "tokenizer_pool_size is 0.")
  305. parser.add_argument("--tokenizer-pool-extra-config",
  306. type=str,
  307. default=EngineArgs.tokenizer_pool_extra_config,
  308. help="Category: Model Options\n"
  309. "Extra config for tokenizer pool. "
  310. "This should be a JSON string that will be "
  311. "parsed into a dictionary. Ignored if "
  312. "tokenizer_pool_size is 0.")
  313. # Multimodal related configs
  314. parser.add_argument(
  315. '--limit-mm-per-prompt',
  316. type=nullable_kvs,
  317. default=EngineArgs.limit_mm_per_prompt,
  318. # The default value is given in
  319. # MultiModalRegistry.init_mm_limits_per_prompt
  320. help=('For each multimodal plugin, limit how many '
  321. 'input instances to allow for each prompt. '
  322. 'Expects a comma-separated list of items, '
  323. 'e.g.: `image=16,video=2` allows a maximum of 16 '
  324. 'images and 2 videos per prompt. Defaults to 1 for '
  325. 'each modality.'))
  326. parser.add_argument(
  327. "--max-logprobs",
  328. type=int,
  329. default=EngineArgs.max_logprobs,
  330. help="Category: Model Options\n"
  331. "maximum number of log probabilities to "
  332. "return.",
  333. )
  334. # Device Options
  335. parser.add_argument(
  336. "--device",
  337. type=str,
  338. default=EngineArgs.device,
  339. choices=[
  340. "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu"
  341. ],
  342. help=("Category: Model Options\n"
  343. "Device to use for model execution."),
  344. )
  345. # Load Options
  346. parser.add_argument(
  347. '--load-format',
  348. type=str,
  349. default=EngineArgs.load_format,
  350. choices=[
  351. 'auto',
  352. 'pt',
  353. 'safetensors',
  354. 'npcache',
  355. 'dummy',
  356. 'tensorizer',
  357. 'sharded_state',
  358. 'bitsandbytes',
  359. ],
  360. help='Category: Model Options\n'
  361. 'The format of the model weights to load.\n\n'
  362. '* "auto" will try to load the weights in the safetensors format '
  363. 'and fall back to the pytorch bin format if safetensors format '
  364. 'is not available.\n'
  365. '* "pt" will load the weights in the pytorch bin format.\n'
  366. '* "safetensors" will load the weights in the safetensors format.\n'
  367. '* "npcache" will load the weights in pytorch format and store '
  368. 'a numpy cache to speed up the loading.\n'
  369. '* "dummy" will initialize the weights with random values, '
  370. 'which is mainly for profiling.\n'
  371. '* "tensorizer" will load the weights using tensorizer from '
  372. 'CoreWeave. See the Tensorize Aphrodite Model script in the '
  373. 'Examples section for more information.\n'
  374. '* "bitsandbytes" will load the weights using bitsandbytes '
  375. 'quantization.\n')
  376. parser.add_argument(
  377. '--dtype',
  378. type=str,
  379. default=EngineArgs.dtype,
  380. choices=[
  381. 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
  382. ],
  383. help='Category: Model Options\n'
  384. 'Data type for model weights and activations.\n\n'
  385. '* "auto" will use FP16 precision for FP32 and FP16 models, and '
  386. 'BF16 precision for BF16 models.\n'
  387. '* "half" for FP16. Recommended for AWQ quantization.\n'
  388. '* "float16" is the same as "half".\n'
  389. '* "bfloat16" for a balance between precision and range.\n'
  390. '* "float" is shorthand for FP32 precision.\n'
  391. '* "float32" for FP32 precision.')
  392. parser.add_argument(
  393. '--ignore-patterns',
  394. action="append",
  395. type=str,
  396. default=[],
  397. help="Category: Model Options\n"
  398. "The pattern(s) to ignore when loading the model."
  399. "Defaults to 'original/**/*' to avoid repeated loading of llama's "
  400. "checkpoints.")
  401. # Parallel Options
  402. parser.add_argument(
  403. '--worker-use-ray',
  404. action='store_true',
  405. help='Category: Parallel Options\n'
  406. 'Deprecated, use --distributed-executor-backend=ray.')
  407. parser.add_argument(
  408. "--tensor-parallel-size",
  409. "-tp",
  410. type=int,
  411. default=EngineArgs.tensor_parallel_size,
  412. help="Category: Parallel Options\n"
  413. "number of tensor parallel replicas, i.e. the number of GPUs "
  414. "to use.")
  415. parser.add_argument(
  416. "--pipeline-parallel-size",
  417. "-pp",
  418. type=int,
  419. default=EngineArgs.pipeline_parallel_size,
  420. help="Category: Parallel Options\n"
  421. "number of pipeline stages. Currently not supported.")
  422. parser.add_argument(
  423. "--ray-workers-use-nsight",
  424. action="store_true",
  425. help="Category: Parallel Options\n"
  426. "If specified, use nsight to profile ray workers",
  427. )
  428. parser.add_argument(
  429. "--disable-custom-all-reduce",
  430. action="store_true",
  431. default=EngineArgs.disable_custom_all_reduce,
  432. help="Category: Model Options\n"
  433. "See ParallelConfig",
  434. )
  435. parser.add_argument(
  436. '--distributed-executor-backend',
  437. choices=['ray', 'mp'],
  438. default=EngineArgs.distributed_executor_backend,
  439. help='Category: Parallel Options\n'
  440. 'Backend to use for distributed serving. When more than 1 GPU '
  441. 'is used, will be automatically set to "ray" if installed '
  442. 'or "mp" (multiprocessing) otherwise.')
  443. parser.add_argument(
  444. "--max-parallel-loading-workers",
  445. type=int,
  446. default=EngineArgs.max_parallel_loading_workers,
  447. help="Category: Parallel Options\n"
  448. "load model sequentially in multiple batches, "
  449. "to avoid RAM OOM when using tensor "
  450. "parallel and large models",
  451. )
  452. # Quantization Options
  453. parser.add_argument(
  454. "--quantization",
  455. "-q",
  456. type=str,
  457. choices=[*QUANTIZATION_METHODS, None],
  458. default=EngineArgs.quantization,
  459. help="Category: Quantization Options\n"
  460. "Method used to quantize the weights. If "
  461. "None, we first check the `quantization_config` "
  462. "attribute in the model config file. If that is "
  463. "None, we assume the model weights are not "
  464. "quantized and use `dtype` to determine the data "
  465. "type of the weights.",
  466. )
  467. parser.add_argument(
  468. '--quantization-param-path',
  469. type=str,
  470. default=None,
  471. help='Category: Quantization Options\n'
  472. 'Path to the JSON file containing the KV cache '
  473. 'scaling factors. This should generally be supplied, when '
  474. 'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
  475. 'default to 1.0, which may cause accuracy issues. '
  476. 'FP8_E5M2 (without scaling) is only supported on cuda version'
  477. 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
  478. 'supported for common inference criteria. ')
  479. parser.add_argument(
  480. '--preemption-mode',
  481. type=str,
  482. default=None,
  483. help='Category: Scheduler Options\n'
  484. 'If \'recompute\', the engine performs preemption by block '
  485. 'swapping; If \'swap\', the engine performs preemption by block '
  486. 'swapping.')
  487. parser.add_argument("--deepspeed-fp-bits",
  488. type=int,
  489. default=None,
  490. help="Category: Quantization Options\n"
  491. "Number of floating bits to use for the deepspeed "
  492. "quantization. Supported bits are: 4, 6, 8, 12.")
  493. parser.add_argument("--quant-llm-fp-bits",
  494. type=int,
  495. default=None,
  496. help="Category: Quantization Options\n"
  497. "Number of floating bits to use for the quant_llm "
  498. "quantization. Supported bits are: 4 to 15.")
  499. parser.add_argument("--quant-llm-exp-bits",
  500. type=int,
  501. default=None,
  502. help="Category: Quantization Options\n"
  503. "Number of exponent bits to use for the quant_llm "
  504. "quantization. Supported bits are: 1 to 5.")
  505. # Cache Options
  506. parser.add_argument(
  507. '--kv-cache-dtype',
  508. type=str,
  509. choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
  510. default=EngineArgs.kv_cache_dtype,
  511. help='Category: Cache Options\n'
  512. 'Data type for kv cache storage. If "auto", will use model '
  513. 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
  514. 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
  515. parser.add_argument(
  516. "--block-size",
  517. type=int,
  518. default=EngineArgs.block_size,
  519. choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
  520. help="Category: Cache Options\n"
  521. "token block size",
  522. )
  523. parser.add_argument(
  524. "--enable-prefix-caching",
  525. "--context-shift",
  526. action="store_true",
  527. help="Category: Cache Options\n"
  528. "Enable automatic prefix caching.",
  529. )
  530. parser.add_argument(
  531. "--num-gpu-blocks-override",
  532. type=int,
  533. default=None,
  534. help="Category: Cache Options Options\n"
  535. "If specified, ignore GPU profiling result and use this "
  536. "number of GPU blocks. Used for testing preemption.")
  537. parser.add_argument('--disable-sliding-window',
  538. action='store_true',
  539. help='Category: KV Cache Options\n'
  540. 'Disables sliding window, '
  541. 'capping to sliding window size')
  542. parser.add_argument(
  543. "--gpu-memory-utilization",
  544. "-gmu",
  545. type=float,
  546. default=EngineArgs.gpu_memory_utilization,
  547. help="Category: Cache Options\n"
  548. "The fraction of GPU memory to be used for "
  549. "the model executor, which can range from 0 to 1."
  550. "If unspecified, will use the default value of 0.9.",
  551. )
  552. parser.add_argument(
  553. "--swap-space",
  554. type=float,
  555. default=EngineArgs.swap_space,
  556. help="Category: Cache Options\n"
  557. "CPU swap space size (GiB) per GPU",
  558. )
  559. parser.add_argument(
  560. '--cpu-offload-gb',
  561. type=float,
  562. default=0,
  563. help='Category: Cache Options\n'
  564. 'The space in GiB to offload to CPU, per GPU. '
  565. 'Default is 0, which means no offloading. Intuitively, '
  566. 'this argument can be seen as a virtual way to increase '
  567. 'the GPU memory size. For example, if you have one 24 GB '
  568. 'GPU and set this to 10, virtually you can think of it as '
  569. 'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
  570. 'which requires at least 26GB GPU memory. Note that this '
  571. 'requires fast CPU-GPU interconnect, as part of the model is'
  572. 'loaded from CPU memory to GPU memory on the fly in each '
  573. 'model forward pass.')
  574. # Scheduler Options
  575. parser.add_argument("--use-v2-block-manager",
  576. action="store_true",
  577. help="Category: Scheduler Options\n"
  578. "Use the v2 block manager.")
  579. parser.add_argument(
  580. "--scheduler-delay-factor",
  581. "-sdf",
  582. type=float,
  583. default=EngineArgs.scheduler_delay_factor,
  584. help="Category: Scheduler Options\n"
  585. "Apply a delay (of delay factor multiplied by previous "
  586. "prompt latency) before scheduling next prompt.")
  587. parser.add_argument(
  588. "--enable-chunked-prefill",
  589. action=StoreBoolean,
  590. default=EngineArgs.enable_chunked_prefill,
  591. nargs="?",
  592. const="True",
  593. help="Category: Scheduler Options\n"
  594. "If True, the prefill requests can be chunked based on the "
  595. "max_num_batched_tokens.")
  596. parser.add_argument(
  597. '--guided-decoding-backend',
  598. type=str,
  599. default='outlines',
  600. choices=['outlines', 'lm-format-enforcer'],
  601. help='Category: Scheduler Options\n'
  602. 'Which engine will be used for guided decoding'
  603. ' (JSON schema / regex etc) by default. Currently support '
  604. 'https://github.com/outlines-dev/outlines and '
  605. 'https://github.com/noamgat/lm-format-enforcer.'
  606. ' Can be overridden per request via guided_decoding_backend'
  607. ' parameter.')
  608. parser.add_argument(
  609. "--max-num-batched-tokens",
  610. type=int,
  611. default=EngineArgs.max_num_batched_tokens,
  612. help="Category: KV Cache Options\n"
  613. "maximum number of batched tokens per "
  614. "iteration",
  615. )
  616. parser.add_argument(
  617. "--max-num-seqs",
  618. type=int,
  619. default=EngineArgs.max_num_seqs,
  620. help="Category: API Options\n"
  621. "maximum number of sequences per iteration",
  622. )
  623. parser.add_argument('--num-scheduler-steps',
  624. type=int,
  625. default=1,
  626. help=('Maximum number of forward steps per '
  627. 'scheduler call.'))
  628. # Speculative Decoding Options
  629. parser.add_argument("--num-lookahead-slots",
  630. type=int,
  631. default=EngineArgs.num_lookahead_slots,
  632. help="Category: Speculative Decoding Options\n"
  633. "Experimental scheduling config necessary for "
  634. "speculative decoding. This will be replaced by "
  635. "speculative decoding config in the future; it is "
  636. "present for testing purposes until then.")
  637. parser.add_argument(
  638. "--speculative-model",
  639. type=str,
  640. default=EngineArgs.speculative_model,
  641. help="Category: Speculative Decoding Options\n"
  642. "The name of the draft model to be used in speculative decoding.")
  643. # Quantization settings for speculative model.
  644. parser.add_argument(
  645. '--speculative-model-quantization',
  646. type=str,
  647. choices=[*QUANTIZATION_METHODS, None],
  648. default=EngineArgs.speculative_model_quantization,
  649. help='Method used to quantize the weights of speculative model.'
  650. 'If None, we first check the `quantization_config` '
  651. 'attribute in the model config file. If that is '
  652. 'None, we assume the model weights are not '
  653. 'quantized and use `dtype` to determine the data '
  654. 'type of the weights.')
  655. parser.add_argument("--num-speculative-tokens",
  656. type=int,
  657. default=EngineArgs.num_speculative_tokens,
  658. help="Category: Speculative Decoding Options\n"
  659. "The number of speculative tokens to sample from "
  660. "the draft model in speculative decoding")
  661. parser.add_argument(
  662. "--speculative-max-model-len",
  663. type=str,
  664. default=EngineArgs.speculative_max_model_len,
  665. help="Category: Speculative Decoding Options\n"
  666. "The maximum sequence length supported by the "
  667. "draft model. Sequences over this length will skip "
  668. "speculation.")
  669. parser.add_argument(
  670. "--ngram-prompt-lookup-max",
  671. type=int,
  672. default=EngineArgs.ngram_prompt_lookup_max,
  673. help="Category: Speculative Decoding Options\n"
  674. "Max size of window for ngram prompt lookup in speculative "
  675. "decoding.")
  676. parser.add_argument(
  677. "--ngram-prompt-lookup-min",
  678. type=int,
  679. default=EngineArgs.ngram_prompt_lookup_min,
  680. help="Category: Speculative Decoding Options\n"
  681. "Min size of window for ngram prompt lookup in speculative "
  682. "decoding.")
  683. parser.add_argument(
  684. "--speculative-draft-tensor-parallel-size",
  685. "-spec-draft-tp",
  686. type=int,
  687. default=EngineArgs.speculative_draft_tensor_parallel_size,
  688. help="Category: Speculative Decoding Options\n"
  689. "Number of tensor parallel replicas for "
  690. "the draft model in speculative decoding.")
  691. parser.add_argument(
  692. "--speculative-disable-by-batch-size",
  693. type=int,
  694. default=EngineArgs.speculative_disable_by_batch_size,
  695. help="Category: Speculative Decoding Options\n"
  696. "Disable speculative decoding for new incoming requests "
  697. "if the number of enqueue requests is larger than this value.")
  698. parser.add_argument(
  699. '--spec-decoding-acceptance-method',
  700. type=str,
  701. default=EngineArgs.spec_decoding_acceptance_method,
  702. choices=['rejection_sampler', 'typical_acceptance_sampler'],
  703. help='Category: Speculative Decoding Options\n'
  704. 'Specify the acceptance method to use during draft token '
  705. 'verification in speculative decoding. Two types of acceptance '
  706. 'routines are supported: '
  707. '1) RejectionSampler which does not allow changing the '
  708. 'acceptance rate of draft tokens, '
  709. '2) TypicalAcceptanceSampler which is configurable, allowing for '
  710. 'a higher acceptance rate at the cost of lower quality, '
  711. 'and vice versa.')
  712. parser.add_argument(
  713. '--typical-acceptance-sampler-posterior-threshold',
  714. type=float,
  715. default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
  716. help='Category: Speculative Decoding Options\n'
  717. 'Set the lower bound threshold for the posterior '
  718. 'probability of a token to be accepted. This threshold is '
  719. 'used by the TypicalAcceptanceSampler to make sampling decisions '
  720. 'during speculative decoding. Defaults to 0.09')
  721. parser.add_argument(
  722. '--typical-acceptance-sampler-posterior-alpha',
  723. type=float,
  724. default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
  725. help='Category: Speculative Decoding Options\n'
  726. 'A scaling factor for the entropy-based threshold for token '
  727. 'acceptance in the TypicalAcceptanceSampler. Typically defaults '
  728. 'to sqrt of --typical-acceptance-sampler-posterior-threshold '
  729. 'i.e. 0.3')
  730. parser.add_argument(
  731. '--disable-logprobs-during-spec-decoding',
  732. type=bool,
  733. default=EngineArgs.disable_logprobs_during_spec_decoding,
  734. help='Category: Speculative Decoding Options\n'
  735. 'If set to True, token log probabilities are not returned '
  736. 'during speculative decoding. If set to False, log probabilities '
  737. 'are returned according to the settings in SamplingParams. If '
  738. 'not specified, it defaults to True. Disabling log probabilities '
  739. 'during speculative decoding reduces latency by skipping logprob '
  740. 'calculation in proposal sampling, target sampling, and after '
  741. 'accepted tokens are determined.')
  742. # Adapter Options
  743. parser.add_argument(
  744. "--enable-lora",
  745. action="store_true",
  746. help="Category: Adapter Options\n"
  747. "If True, enable handling of LoRA adapters.",
  748. )
  749. parser.add_argument(
  750. "--max-loras",
  751. type=int,
  752. default=EngineArgs.max_loras,
  753. help="Category: Adapter Options\n"
  754. "Max number of LoRAs in a single batch.",
  755. )
  756. parser.add_argument(
  757. "--max-lora-rank",
  758. type=int,
  759. default=EngineArgs.max_lora_rank,
  760. help="Category: Adapter Options\n"
  761. "Max LoRA rank.",
  762. )
  763. parser.add_argument(
  764. "--lora-extra-vocab-size",
  765. type=int,
  766. default=EngineArgs.lora_extra_vocab_size,
  767. help=("Category: Adapter Options\n"
  768. "Maximum size of extra vocabulary that can be "
  769. "present in a LoRA adapter (added to the base "
  770. "model vocabulary)."),
  771. )
  772. parser.add_argument(
  773. "--lora-dtype",
  774. type=str,
  775. default=EngineArgs.lora_dtype,
  776. choices=["auto", "float16", "bfloat16", "float32"],
  777. help=("Category: Adapter Options\n"
  778. "Data type for LoRA. If auto, will default to "
  779. "base model dtype."),
  780. )
  781. parser.add_argument(
  782. "--max-cpu-loras",
  783. type=int,
  784. default=EngineArgs.max_cpu_loras,
  785. help=("Category: Adapter Options\n"
  786. "Maximum number of LoRAs to store in CPU memory. "
  787. "Must be >= than max_num_seqs. "
  788. "Defaults to max_num_seqs."),
  789. )
  790. parser.add_argument(
  791. "--long-lora-scaling-factors",
  792. type=str,
  793. default=EngineArgs.long_lora_scaling_factors,
  794. help=("Category: Adapter Options\n"
  795. "Specify multiple scaling factors (which can "
  796. "be different from base model scaling factor "
  797. "- see eg. Long LoRA) to allow for multiple "
  798. "LoRA adapters trained with those scaling "
  799. "factors to be used at the same time. If not "
  800. "specified, only adapters trained with the "
  801. "base model scaling factor are allowed."))
  802. parser.add_argument(
  803. "--fully-sharded-loras",
  804. action='store_true',
  805. help=("Category: Adapter Options\n"
  806. "By default, only half of the LoRA computation is sharded "
  807. "with tensor parallelism. Enabling this will use the fully "
  808. "sharded layers. At high sequence length, max rank or "
  809. "tensor parallel size, this is likely faster."))
  810. parser.add_argument("--qlora-adapter-name-or-path",
  811. type=str,
  812. default=None,
  813. help="Category: Adapter Options\n"
  814. "Name or path of the LoRA adapter to use.")
  815. parser.add_argument('--enable-prompt-adapter',
  816. action='store_true',
  817. help='Category: Adapter Options\n'
  818. 'If True, enable handling of PromptAdapters.')
  819. parser.add_argument('--max-prompt-adapters',
  820. type=int,
  821. default=EngineArgs.max_prompt_adapters,
  822. help='Category: Adapter Options\n'
  823. 'Max number of PromptAdapters in a batch.')
  824. parser.add_argument('--max-prompt-adapter-token',
  825. type=int,
  826. default=EngineArgs.max_prompt_adapter_token,
  827. help='Category: Adapter Options\n'
  828. 'Max number of PromptAdapters tokens')
  829. # Log Options
  830. parser.add_argument(
  831. "--disable-log-stats",
  832. action="store_true",
  833. help="Category: Log Options\n"
  834. "disable logging statistics",
  835. )
  836. return parser
  837. @classmethod
  838. def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
  839. # Get the list of attributes of this dataclass.
  840. attrs = [attr.name for attr in dataclasses.fields(cls)]
  841. # Set the attributes from the parsed arguments.
  842. engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
  843. return engine_args
    def create_engine_config(self, ) -> EngineConfig:
        """Validate these arguments and assemble them into an EngineConfig.

        Builds the device, model, cache, parallel, speculative, scheduler,
        LoRA, load, prompt-adapter and decoding configs from the fields of
        ``self`` and returns them bundled in a single ``EngineConfig``.

        This method mutates ``self`` along the way:
          * ``quantization`` and ``load_format`` are both set to "gguf" when
            ``check_gguf_file(self.model)`` is true;
          * ``enable_chunked_prefill`` is resolved from None to True/False;
          * ``model_loader_extra_config`` gains a
            "qlora_adapter_name_or_path" entry when a non-empty QLoRA
            adapter is configured.

        Raises:
            ValueError: if bitsandbytes quantization / a QLoRA adapter and
                the load format are inconsistent, or if chunked prefill is
                enabled together with a sliding window while the v2 block
                manager is not in use.
            AssertionError: if ``cpu_offload_gb`` is negative.
            NotImplementedError: if ``num_scheduler_steps`` is greater than 1.

        Errors raised by the individual config constructors are not caught
        here.
        """
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")

        # Mirror image of the check above: a bitsandbytes load format (or a
        # QLoRA adapter) requires bitsandbytes quantization.
        if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        # NOTE(review): this is an assert, so the check disappears when
        # Python runs with -O.
        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            deepspeed_fp_bits=self.deepspeed_fp_bits,
            quant_llm_fp_bits=self.quant_llm_fp_bits,
            quant_llm_exp_bits=self.quant_llm_exp_bits,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
        )
        # The cache config takes its attention-free flag and sliding window
        # from the model config built above, not from the raw arguments.
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free(),
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                tokenizer_pool_size=self.tokenizer_pool_size,
                tokenizer_pool_type=self.tokenizer_pool_type,
                tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        # Use the length resolved by ModelConfig rather than self.max_model_len.
        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
            if use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                has_seqlen_agnostic_layers = (
                    model_config.contains_seqlen_agnostic_layers(
                        parallel_config))
                # Auto-enable only on CUDA and only when none of the
                # features listed below is in use.
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and not self.enable_prefix_caching
                        and not has_seqlen_agnostic_layers):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            # Still undecided after the auto-enable attempt: default to off so
            # the attribute is always a bool from here on.
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                f"The model has a long context length ({max_model_len}). "
                "This may cause OOM errors during the initial memory "
                "profiling phase, or result in low performance due to small "
                "KV cache space. Consider setting --max-model-len to a "
                "smaller value.")

        # The result may be None; the code below checks for that before
        # reading num_lookahead_slots from it.
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization = \
                self.speculative_model_quantization,
            speculative_draft_tensor_parallel_size=self.
            speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
            self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        if self.num_scheduler_steps > 1:
            raise NotImplementedError("Multi-step is not yet supported.")
            # NOTE(review): the two checks below appear to be nested under
            # this branch (their messages refer to multi-step), which makes
            # them unreachable after the unconditional raise above.
            # Presumably kept for when multi-step is enabled - confirm.
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill:
                raise ValueError("Chunked prefill is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")

        # make sure num_lookahead_slots is set the higher value depending on
        # if we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        # When a speculative config exists, its own num_lookahead_slots
        # replaces the value computed above.
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            is_attention_free=model_config.is_attention_free(),
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
        )
        # lora_config is None unless --enable-lora is set. A falsy or
        # non-positive max_cpu_loras is forwarded as None.
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        # Pass a non-empty QLoRA adapter path to the loader through
        # model_loader_extra_config, creating the dict if it is still None.
        if self.qlora_adapter_name_or_path is not None and \
            self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns)

        # prompt_adapter_config is None unless --enable-prompt-adapter is set.
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
                                        if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        # This only rejects the combination when the v2 block manager is not
        # in use, although the message does not mention that condition.
        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config,
                            prompt_adapter_config=prompt_adapter_config)
  1048. @dataclass
  1049. class AsyncEngineArgs(EngineArgs):
  1050. """Arguments for asynchronous Aphrodite engine."""
  1051. engine_use_ray: bool = False
  1052. disable_log_requests: bool = False
  1053. uvloop: bool = False
  1054. @staticmethod
  1055. def add_cli_args(parser: FlexibleArgumentParser,
  1056. async_args_only: bool = False) -> FlexibleArgumentParser:
  1057. if not async_args_only:
  1058. parser = EngineArgs.add_cli_args(parser)
  1059. parser.add_argument('--engine-use-ray',
  1060. action='store_true',
  1061. help='Use Ray to start the LLM engine in a '
  1062. 'separate process as the server process.')
  1063. parser.add_argument('--disable-log-requests',
  1064. action='store_true',
  1065. help='Disable logging requests.')
  1066. parser.add_argument(
  1067. "--uvloop",
  1068. action="store_true",
  1069. help="Use the Uvloop asyncio event loop to possibly increase "
  1070. "performance")
  1071. return parser
  1072. class StoreBoolean(argparse.Action):
  1073. def __call__(self, parser, namespace, values, option_string=None):
  1074. if values.lower() == "true":
  1075. setattr(namespace, self.dest, True)
  1076. elif values.lower() == "false":
  1077. setattr(namespace, self.dest, False)
  1078. else:
  1079. raise ValueError(f"Invalid boolean value: {values}. "
  1080. "Expected 'true' or 'false'.")