# args_tools.py
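"""Dataclasses and CLI argument definitions for configuring the Aphrodite
engine (EngineArgs / AsyncEngineArgs) and turning them into an EngineConfig."""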

import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
                                     EngineConfig, LoadConfig, LoRAConfig,
                                     ModelConfig, MultiModalConfig,
                                     ParallelConfig, PromptAdapterConfig,
                                     SchedulerConfig, SpeculativeConfig,
                                     TokenizerPoolConfig)
from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
from aphrodite.executor.executor_base import ExecutorBase
from aphrodite.quantization import QUANTIZATION_METHODS

if TYPE_CHECKING:
    from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import \
        BaseTokenizerGroup  # noqa: E501


@dataclass
class EngineArgs:
    """Arguments for Aphrodite engine."""
    model: str
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = "auto"
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: Optional[bool] = False
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    disable_sliding_window: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    cpu_offload_gb: int = 0  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 10  # OpenAI default is 5, setting to 10 because ST
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    load_in_smooth: bool = False
    deepspeed_fp_bits: Optional[int] = None
    enforce_eager: bool = True
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    lora_dtype: str = "auto"
    max_cpu_loras: Optional[int] = None
    device: str = "auto"
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None
    preemption_mode: Optional[str] = None
    # Scheduler config
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False
    guided_decoding_backend: str = 'outlines'
    # Speculative decoding config
    speculative_model: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None
    qlora_adapter_name_or_path: Optional[str] = None

    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model
        if is_cpu():
            self.distributed_executor_backend = None

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for the Aphrodite engine."""
        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst
        # Model arguments
        parser.add_argument(
            "--model",
            type=str,
            default="EleutherAI/pythia-70m-deduped",
            help="name or path of the huggingface model to use",
        )
        parser.add_argument(
            "--tokenizer",
            type=str,
            default=EngineArgs.tokenizer,
            help="name or path of the huggingface tokenizer to use",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="Skip initialization of tokenizer and detokenizer")
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="the specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--code-revision",
            type=str,
            default=None,
            help="the specific revision to use for the model code on "
            "Hugging Face Hub. It can be a branch name, a tag name, or a "
            "commit id. If unspecified, will use the default version.",
        )
        parser.add_argument(
            "--tokenizer-revision",
            type=str,
            default=None,
            help="the specific tokenizer version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help='tokenizer mode. "auto" will use the fast '
            'tokenizer if available, and "slow" will '
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="trust remote code from huggingface",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=EngineArgs.download_dir,
            help="directory to download and load the weights, "
            "default to the default cache dir of "
            "huggingface",
        )
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto',
                'pt',
                'safetensors',
                'npcache',
                'dummy',
                'tensorizer',
                'sharded_state',
                'bitsandbytes',
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize Aphrodite Model script in the '
            'Examples section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied when '
            'the KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on CUDA versions '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            "--max-model-len",
            type=int,
            default=EngineArgs.max_model_len,
            help="model context length. If unspecified, "
            "will be automatically derived from the model.",
        )
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently supports '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via the guided_decoding_backend'
            ' parameter.')
        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument(
            "--pipeline-parallel-size",
            "-pp",
            type=int,
            default=EngineArgs.pipeline_parallel_size,
            help="number of pipeline stages. Currently not supported.")
        parser.add_argument(
            "--tensor-parallel-size",
            "-tp",
            type=int,
            default=EngineArgs.tensor_parallel_size,
            help="number of tensor parallel replicas, i.e. the number of GPUs "
            "to use.")
        parser.add_argument(
            "--max-parallel-loading-workers",
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help="load model sequentially in multiple batches, "
            "to avoid RAM OOM when using tensor "
            "parallel and large models",
        )
        parser.add_argument(
            "--ray-workers-use-nsight",
            action="store_true",
            help="If specified, use nsight to profile ray workers",
        )
        # KV cache arguments
        parser.add_argument(
            "--block-size",
            type=int,
            default=EngineArgs.block_size,
            choices=[8, 16, 32],
            help="token block size",
        )
        parser.add_argument(
            "--enable-prefix-caching",
            "--context-shift",
            action="store_true",
            help="Enable automatic prefix caching.",
        )
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument("--use-v2-block-manager",
                            action="store_true",
                            help="Use the v2 block manager.")
        parser.add_argument(
            "--num-lookahead-slots",
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help="Experimental scheduling config necessary for "
            "speculative decoding. This will be replaced by "
            "speculative decoding config in the future; it is "
            "present for testing purposes until then.")
        parser.add_argument("--seed",
                            type=int,
                            default=EngineArgs.seed,
                            help="random seed")
        parser.add_argument(
            "--swap-space",
            type=int,
            default=EngineArgs.swap_space,
            help="CPU swap space size (GiB) per GPU",
        )
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weights, '
            'which requires at least 26 GB of GPU memory. Note that this '
            'requires a fast CPU-GPU interconnect, as part of the model is '
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
        parser.add_argument(
            "--gpu-memory-utilization",
            "-gmu",
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="the fraction of GPU memory to be used for "
            "the model executor, which can range from 0 to 1. "
            "If unspecified, will use the default value of 0.9.",
        )
        parser.add_argument(
            "--num-gpu-blocks-override",
            type=int,
            default=None,
            help="If specified, ignore GPU profiling result and use this "
            "number of GPU blocks. Used for testing preemption.")
        parser.add_argument(
            "--max-num-batched-tokens",
            type=int,
            default=EngineArgs.max_num_batched_tokens,
            help="maximum number of batched tokens per "
            "iteration",
        )
        parser.add_argument(
            "--max-num-seqs",
            type=int,
            default=EngineArgs.max_num_seqs,
            help="maximum number of sequences per iteration",
        )
        parser.add_argument(
            "--max-logprobs",
            type=int,
            default=EngineArgs.max_logprobs,
            help="maximum number of log probabilities to "
            "return.",
        )
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="disable logging statistics",
        )
        # Quantization settings.
        parser.add_argument(
            "--quantization",
            "-q",
            type=str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.quantization,
            help="Method used to quantize the weights. If "
            "None, we first check the `quantization_config` "
            "attribute in the model config file. If that is "
            "None, we assume the model weights are not "
            "quantized and use `dtype` to determine the data "
            "type of the weights.",
        )
        parser.add_argument(
            "--load-in-4bit",
            action="store_true",
            help="Load the FP16 model in 4-bit format. Also "
            "works with AWQ models. Throughput at 2.5x of "
            "FP16.",
        )
        parser.add_argument(
            "--load-in-8bit",
            action="store_true",
            help="Load the FP16 model in 8-bit format. "
            "Throughput at 0.3x of FP16.",
        )
        parser.add_argument(
            "--load-in-smooth",
            action="store_true",
            help="Load the FP16 model in smoothquant "
            "8bit format. Throughput at 0.7x of FP16.",
        )
        parser.add_argument(
            "--deepspeed-fp-bits",
            type=int,
            default=None,
            help="Number of floating-point bits to use for the deepspeed "
            "quantization. Supported bits are: 4, 6, 8, 12.")
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument(
            "--enforce-eager",
            type=lambda x: (str(x).lower() == 'true'),
            default=EngineArgs.enforce_eager,
            help="Always use eager-mode PyTorch. If False, "
            "will use eager mode and CUDA graph in hybrid "
            "for maximal performance and flexibility.",
        )
        parser.add_argument("--max-context-len-to-capture",
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help="Maximum context length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode. "
                            "(DEPRECATED. Use --max-seq_len-to-capture instead"
                            ")")
        parser.add_argument("--max-seq_len-to-capture",
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help="Maximum sequence length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode.")
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=EngineArgs.disable_custom_all_reduce,
            help="See ParallelConfig",
        )
        parser.add_argument("--tokenizer-pool-size",
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help="Size of tokenizer pool to use for "
                            "asynchronous tokenization. If 0, will "
                            "use synchronous tokenization.")
        parser.add_argument("--tokenizer-pool-type",
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help="The type of tokenizer pool to use for "
                            "asynchronous tokenization. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument("--tokenizer-pool-extra-config",
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help="Extra config for tokenizer pool. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=None,
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; if \'swap\', the engine performs preemption by '
            'block swapping.')
        # LoRA related configs
        parser.add_argument(
            "--enable-lora",
            action="store_true",
            help="If True, enable handling of LoRA adapters.",
        )
        parser.add_argument(
            "--max-loras",
            type=int,
            default=EngineArgs.max_loras,
            help="Max number of LoRAs in a single batch.",
        )
        parser.add_argument(
            "--max-lora-rank",
            type=int,
            default=EngineArgs.max_lora_rank,
            help="Max LoRA rank.",
        )
        parser.add_argument(
            "--lora-extra-vocab-size",
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=("Maximum size of extra vocabulary that can be "
                  "present in a LoRA adapter (added to the base "
                  "model vocabulary)."),
        )
        parser.add_argument(
            "--lora-dtype",
            type=str,
            default=EngineArgs.lora_dtype,
            choices=["auto", "float16", "bfloat16", "float32"],
            help=("Data type for LoRA. If auto, will default to "
                  "base model dtype."),
        )
        parser.add_argument(
            "--long-lora-scaling-factors",
            type=str,
            default=EngineArgs.long_lora_scaling_factors,
            help=("Specify multiple scaling factors (which can "
                  "be different from base model scaling factor "
                  "- see eg. Long LoRA) to allow for multiple "
                  "LoRA adapters trained with those scaling "
                  "factors to be used at the same time. If not "
                  "specified, only adapters trained with the "
                  "base model scaling factor are allowed."))
        parser.add_argument(
            "--max-cpu-loras",
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=("Maximum number of LoRAs to store in CPU memory. "
                  "Must be >= max_num_seqs. "
                  "Defaults to max_num_seqs."),
        )
        parser.add_argument(
            "--fully-sharded-loras",
            action='store_true',
            help=("By default, only half of the LoRA computation is sharded "
                  "with tensor parallelism. Enabling this will use the fully "
                  "sharded layers. At high sequence length, max rank or "
                  "tensor parallel size, this is likely faster."))
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
        parser.add_argument(
            "--device",
            type=str,
            default=EngineArgs.device,
            choices=[
                "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu"
            ],
            help=("Device to use for model execution."),
        )
        parser.add_argument(
            "--scheduler-delay-factor",
            "-sdf",
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help="Apply a delay (of delay factor multiplied by previous "
            "prompt latency) before scheduling next prompt.")
        parser.add_argument(
            "--enable-chunked-prefill",
            action="store_true",
            help="If True, the prefill requests can be chunked based on the "
            "max_num_batched_tokens.")
        parser.add_argument(
            "--speculative-model",
            type=str,
            default=EngineArgs.speculative_model,
            help=
            "The name of the draft model to be used in speculative decoding.")
        parser.add_argument(
            "--speculative-draft-tensor-parallel-size",
            "-spec-draft-tp",
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help="Number of tensor parallel replicas for "
            "the draft model in speculative decoding.")
        parser.add_argument(
            "--num-speculative-tokens",
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help="The number of speculative tokens to sample from "
            "the draft model in speculative decoding")
        parser.add_argument(
            "--speculative-max-model-len",
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help="The maximum sequence length supported by the "
            "draft model. Sequences over this length will skip "
            "speculation.")
        parser.add_argument(
            "--speculative-disable-by-batch-size",
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help="Disable speculative decoding for new incoming requests "
            "if the number of enqueued requests is larger than this value.")
        parser.add_argument(
            "--ngram-prompt-lookup-max",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help="Max size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--ngram-prompt-lookup-min",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help="Min size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
            # Parse "true"/"false" strings explicitly; a bare `type=bool`
            # would treat any non-empty string (including "False") as True.
            type=lambda x: (str(x).lower() == 'true'),
            default=EngineArgs.disable_logprobs_during_spec_decoding,
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')
        parser.add_argument("--model-loader-extra-config",
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help="Extra config for model loader. "
                            "This will be passed to the model loader "
                            "corresponding to the chosen load_format. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary.")
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in the `model_name` tag content of "
            "prometheus metrics; if multiple names are provided, the "
            "metrics tag will take the first one.")
        parser.add_argument("--qlora-adapter-name-or-path",
                            type=str,
                            default=None,
                            help="Name or path of the QLoRA adapter to use.")
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args
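
    # Illustrative usage only (a sketch, not part of the original code):
    # `add_cli_args` and `from_cli_args` are meant to be used together,
    # roughly like this:
    #
    #     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    #     engine_args = EngineArgs.from_cli_args(parser.parse_args())
    #
    # Because `from_cli_args` reads every dataclass field from the parsed
    # namespace, the parser must carry all of the arguments added above.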

    def create_engine_config(self) -> EngineConfig:
        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")
        if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")
        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        multimodal_config = MultiModalConfig()
        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            load_in_4bit=self.load_in_4bit,
            load_in_8bit=self.load_in_8bit,
            load_in_smooth=self.load_in_smooth,
            deepspeed_fp_bits=self.deepspeed_fp_bits,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=multimodal_config,
        )
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                tokenizer_pool_size=self.tokenizer_pool_size,
                tokenizer_pool_type=self.tokenizer_pool_type,
                tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_draft_tensor_parallel_size=self.
            speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )
        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
                self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
            if self.enable_prompt_adapter else None
        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            multimodal_config=multimodal_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config,
                            prompt_adapter_config=prompt_adapter_config)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous Aphrodite engine."""
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: int = 0
    uvloop: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        parser.add_argument('--max-log-len',
                            type=int,
                            default=0,
                            help='Max number of prompt characters or prompt '
                            'ID numbers being printed in log.'
                            '\n\nDefault: Unlimited')
        parser.add_argument(
            "--uvloop",
            action="store_true",
            help="Use the Uvloop asyncio event loop to possibly increase "
            "performance")
        return parser
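

# The block below is an illustrative sketch only; it is not part of the
# original module, and the real server/CLI entrypoints live elsewhere in
# Aphrodite. It assumes FlexibleArgumentParser accepts the standard
# argparse.ArgumentParser constructor arguments.
if __name__ == "__main__":
    # Build a parser carrying both the shared engine arguments and the
    # async-engine-only arguments.
    example_parser = AsyncEngineArgs.add_cli_args(
        FlexibleArgumentParser(description="Aphrodite engine args demo"))
    example_args = example_parser.parse_args()
    # Convert the parsed namespace into the dataclass form used by the engine.
    example_engine_args = AsyncEngineArgs.from_cli_args(example_args)
    print(example_engine_args)
    # `create_engine_config()` would then resolve the full EngineConfig; note
    # that it loads the Hugging Face config for the chosen --model.
    # example_config = example_engine_args.create_engine_config()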