# args_tools.py

import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

from aphrodite.common.config import (
    CacheConfig, ControlVectorConfig, DecodingConfig, DeviceConfig,
    EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig,
    ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig,
    TokenizerPoolConfig)
from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
from aphrodite.executor.executor_base import ExecutorBase
from aphrodite.quantization import QUANTIZATION_METHODS

if TYPE_CHECKING:
    from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import \
        BaseTokenizerGroup  # noqa: E501


@dataclass
class EngineArgs:
    """Arguments for Aphrodite engine."""
    model: str
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = "auto"
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
    quantization_param_path: Optional[str] = None
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: Optional[bool] = False
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    enable_prefix_caching: bool = False
    disable_sliding_window: bool = False
    use_v2_block_manager: bool = False
    swap_space: int = 4  # GiB
    cpu_offload_gb: int = 0  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_logprobs: int = 10  # OpenAI default is 5, setting to 10 because ST
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    load_in_smooth: bool = False
    deepspeed_fp_bits: Optional[int] = None
    enforce_eager: bool = True
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    enable_control_vector: bool = False
    max_control_vectors: int = 1
    normalize_control_vector: bool = False
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    lora_dtype: str = "auto"
    max_cpu_loras: Optional[int] = None
    device: str = "auto"
    ray_workers_use_nsight: bool = False
    num_gpu_blocks_override: Optional[int] = None
    num_lookahead_slots: int = 0
    model_loader_extra_config: Optional[dict] = None
    preemption_mode: Optional[str] = None
    # Scheduler config
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False
    guided_decoding_backend: str = 'outlines'
    # Speculative decoding config
    speculative_model: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None
    qlora_adapter_name_or_path: Optional[str] = None

    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model
        if is_cpu():
            self.distributed_executor_backend = None
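
    # A minimal usage sketch (illustrative only; the model name and the field
    # values below are examples, and `create_engine_config` is the method
    # defined further down in this class):
    #
    #   engine_args = EngineArgs(model="EleutherAI/pythia-70m-deduped",
    #                            tensor_parallel_size=1)
    #   engine_config = engine_args.create_engine_config()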

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for the Aphrodite engine."""

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            "--model",
            type=str,
            default="EleutherAI/pythia-70m-deduped",
            help="name or path of the huggingface model to use",
        )
        parser.add_argument(
            "--tokenizer",
            type=str,
            default=EngineArgs.tokenizer,
            help="name or path of the huggingface tokenizer to use",
        )
        parser.add_argument(
            "--skip-tokenizer-init",
            action="store_true",
            help="Skip initialization of tokenizer and detokenizer")
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="the specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--code-revision",
            type=str,
            default=None,
            help="the specific revision to use for the model code on "
            "Hugging Face Hub. It can be a branch name, a tag name, or a "
            "commit id. If unspecified, will use the default version.",
        )
        parser.add_argument(
            "--tokenizer-revision",
            type=str,
            default=None,
            help="the specific tokenizer version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help='tokenizer mode. "auto" will use the fast '
            'tokenizer if available, and "slow" will '
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="trust remote code from huggingface",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=EngineArgs.download_dir,
            help="directory to download and load the weights, "
            "default to the default cache dir of "
            "huggingface",
        )
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto',
                'pt',
                'safetensors',
                'npcache',
                'dummy',
                'tensorizer',
                'sharded_state',
                'bitsandbytes',
            ],
            help='The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize Aphrodite Model script in the '
            'Examples section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on CUDA versions '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            "--max-model-len",
            type=int,
            default=EngineArgs.max_model_len,
            help="model context length. If unspecified, "
            "will be automatically derived from the model.",
        )
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding '
            '(JSON schema / regex etc) by default. Currently supports '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer. '
            'Can be overridden per request via the guided_decoding_backend '
            'parameter.')

        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument(
            "--pipeline-parallel-size",
            "-pp",
            type=int,
            default=EngineArgs.pipeline_parallel_size,
            help="number of pipeline stages. Currently not supported.")
        parser.add_argument(
            "--tensor-parallel-size",
            "-tp",
            type=int,
            default=EngineArgs.tensor_parallel_size,
            help="number of tensor parallel replicas, i.e. the number of GPUs "
            "to use.")
        parser.add_argument(
            "--max-parallel-loading-workers",
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help="load model sequentially in multiple batches, "
            "to avoid RAM OOM when using tensor "
            "parallel and large models",
        )
        parser.add_argument(
            "--ray-workers-use-nsight",
            action="store_true",
            help="If specified, use nsight to profile ray workers",
        )
        # KV cache arguments
        parser.add_argument(
            "--block-size",
            type=int,
            default=EngineArgs.block_size,
            choices=[8, 16, 32],
            help="token block size",
        )
        parser.add_argument(
            "--enable-prefix-caching",
            "--context-shift",
            action="store_true",
            help="Enable automatic prefix caching.",
        )
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument("--use-v2-block-manager",
                            action="store_true",
                            help="Use the v2 block manager.")
        parser.add_argument(
            "--num-lookahead-slots",
            type=int,
            default=EngineArgs.num_lookahead_slots,
            help="Experimental scheduling config necessary for "
            "speculative decoding. This will be replaced by "
            "speculative decoding config in the future; it is "
            "present for testing purposes until then.")
        parser.add_argument("--seed",
                            type=int,
                            default=EngineArgs.seed,
                            help="random seed")
        parser.add_argument(
            "--swap-space",
            type=int,
            default=EngineArgs.swap_space,
            help="CPU swap space size (GiB) per GPU",
        )
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weights, '
            'which requires at least 26 GB of GPU memory. Note that this '
            'requires a fast CPU-GPU interconnect, as part of the model is '
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
        parser.add_argument(
            "--gpu-memory-utilization",
            "-gmu",
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="the fraction of GPU memory to be used for "
            "the model executor, which can range from 0 to 1. "
            "If unspecified, will use the default value of 0.9.",
        )
        parser.add_argument(
            "--num-gpu-blocks-override",
            type=int,
            default=None,
            help="If specified, ignore GPU profiling result and use this "
            "number of GPU blocks. Used for testing preemption.")
        parser.add_argument(
            "--max-num-batched-tokens",
            type=int,
            default=EngineArgs.max_num_batched_tokens,
            help="maximum number of batched tokens per "
            "iteration",
        )
        parser.add_argument(
            "--max-num-seqs",
            type=int,
            default=EngineArgs.max_num_seqs,
            help="maximum number of sequences per iteration",
        )
        parser.add_argument(
            "--max-logprobs",
            type=int,
            default=EngineArgs.max_logprobs,
            help="maximum number of log probabilities to "
            "return.",
        )
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="disable logging statistics",
        )
        # Quantization settings.
        parser.add_argument(
            "--quantization",
            "-q",
            type=str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.quantization,
            help="Method used to quantize the weights. If "
            "None, we first check the `quantization_config` "
            "attribute in the model config file. If that is "
            "None, we assume the model weights are not "
            "quantized and use `dtype` to determine the data "
            "type of the weights.",
        )
        parser.add_argument(
            "--load-in-4bit",
            action="store_true",
            help="Load the FP16 model in 4-bit format. Also "
            "works with AWQ models. Throughput at 2.5x of "
            "FP16.",
        )
        parser.add_argument(
            "--load-in-8bit",
            action="store_true",
            help="Load the FP16 model in 8-bit format. "
            "Throughput at 0.3x of FP16.",
        )
        parser.add_argument(
            "--load-in-smooth",
            action="store_true",
            help="Load the FP16 model in smoothquant "
            "8bit format. Throughput at 0.7x of FP16.",
        )
        parser.add_argument(
            "--deepspeed-fp-bits",
            type=int,
            default=None,
            help="Number of floating-point bits to use for the deepspeed "
            "quantization. Supported bits are: 4, 6, 8, 12.")
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument(
            "--enforce-eager",
            type=lambda x: (str(x).lower() == 'true'),
            default=EngineArgs.enforce_eager,
            help="Always use eager-mode PyTorch. If False, "
            "will use eager mode and CUDA graph in hybrid "
            "for maximal performance and flexibility.",
        )
        parser.add_argument("--max-context-len-to-capture",
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help="Maximum context length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode. "
                            "(DEPRECATED. Use --max-seq-len-to-capture instead"
                            ")")
        parser.add_argument("--max-seq-len-to-capture",
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help="Maximum sequence length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode.")
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=EngineArgs.disable_custom_all_reduce,
            help="See ParallelConfig",
        )
        parser.add_argument("--tokenizer-pool-size",
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help="Size of tokenizer pool to use for "
                            "asynchronous tokenization. If 0, will "
                            "use synchronous tokenization.")
        parser.add_argument("--tokenizer-pool-type",
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help="The type of tokenizer pool to use for "
                            "asynchronous tokenization. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument("--tokenizer-pool-extra-config",
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help="Extra config for tokenizer pool. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=None,
            help='If \'recompute\', the engine performs preemption by '
            'recomputing; if \'swap\', the engine performs preemption by '
            'block swapping.')
        # LoRA related configs
        parser.add_argument(
            "--enable-lora",
            action="store_true",
            help="If True, enable handling of LoRA adapters.",
        )
        parser.add_argument(
            "--max-loras",
            type=int,
            default=EngineArgs.max_loras,
            help="Max number of LoRAs in a single batch.",
        )
        parser.add_argument(
            "--max-lora-rank",
            type=int,
            default=EngineArgs.max_lora_rank,
            help="Max LoRA rank.",
        )
        parser.add_argument(
            "--lora-extra-vocab-size",
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=("Maximum size of extra vocabulary that can be "
                  "present in a LoRA adapter (added to the base "
                  "model vocabulary)."),
        )
        parser.add_argument(
            "--lora-dtype",
            type=str,
            default=EngineArgs.lora_dtype,
            choices=["auto", "float16", "bfloat16", "float32"],
            help=("Data type for LoRA. If auto, will default to "
                  "base model dtype."),
        )
        parser.add_argument(
            "--long-lora-scaling-factors",
            type=str,
            default=EngineArgs.long_lora_scaling_factors,
            help=("Specify multiple scaling factors (which can "
                  "be different from the base model scaling factor "
                  "- see e.g. Long LoRA) to allow for multiple "
                  "LoRA adapters trained with those scaling "
                  "factors to be used at the same time. If not "
                  "specified, only adapters trained with the "
                  "base model scaling factor are allowed."))
        parser.add_argument(
            "--max-cpu-loras",
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=("Maximum number of LoRAs to store in CPU memory. "
                  "Must be >= max_num_seqs. "
                  "Defaults to max_num_seqs."),
        )
        parser.add_argument(
            "--fully-sharded-loras",
            action='store_true',
            help=("By default, only half of the LoRA computation is sharded "
                  "with tensor parallelism. Enabling this will use the fully "
                  "sharded layers. At high sequence length, max rank or "
                  "tensor parallel size, this is likely faster."))
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Max number of PromptAdapters tokens')
        parser.add_argument('--enable-control-vector',
                            action='store_true',
                            help='If True, enable handling of ControlVectors.')
        parser.add_argument('--max-control-vectors',
                            type=int,
                            default=EngineArgs.max_control_vectors,
                            help='Max number of Control Vectors in a batch.')
        parser.add_argument('--normalize-control-vector',
                            type=bool,
                            default=EngineArgs.normalize_control_vector,
                            help='Enable normalization of control vector')
        parser.add_argument(
            "--device",
            type=str,
            default=EngineArgs.device,
            choices=[
                "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu"
            ],
            help=("Device to use for model execution."),
        )
        parser.add_argument(
            "--scheduler-delay-factor",
            "-sdf",
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help="Apply a delay (of delay factor multiplied by previous "
            "prompt latency) before scheduling the next prompt.")
        parser.add_argument(
            "--enable-chunked-prefill",
            action="store_true",
            help="If True, the prefill requests can be chunked based on the "
            "max_num_batched_tokens.")
        parser.add_argument(
            "--speculative-model",
            type=str,
            default=EngineArgs.speculative_model,
            help=
            "The name of the draft model to be used in speculative decoding.")
        parser.add_argument(
            "--speculative-draft-tensor-parallel-size",
            "-spec-draft-tp",
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help="Number of tensor parallel replicas for "
            "the draft model in speculative decoding.")
        parser.add_argument(
            "--num-speculative-tokens",
            type=int,
            default=EngineArgs.num_speculative_tokens,
            help="The number of speculative tokens to sample from "
            "the draft model in speculative decoding.")
        parser.add_argument(
            "--speculative-max-model-len",
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help="The maximum sequence length supported by the "
            "draft model. Sequences over this length will skip "
            "speculation.")
        parser.add_argument(
            "--speculative-disable-by-batch-size",
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help="Disable speculative decoding for new incoming requests "
            "if the number of enqueued requests is larger than this value.")
        parser.add_argument(
            "--ngram-prompt-lookup-max",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help="Max size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--ngram-prompt-lookup-min",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help="Min size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09.')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold, '
            'i.e. 0.3.')
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
            type=bool,
            default=EngineArgs.disable_logprobs_during_spec_decoding,
            help='If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')
        parser.add_argument("--model-loader-extra-config",
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help="Extra config for model loader. "
                            "This will be passed to the model loader "
                            "corresponding to the chosen load_format. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary.")
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in the `model_name` tag content of "
            "prometheus metrics; if multiple names are provided, the "
            "metrics tag will take the first one.")
        parser.add_argument("--qlora-adapter-name-or-path",
                            type=str,
                            default=None,
                            help="Name or path of the LoRA adapter to use.")

        return parser
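
    # Typical CLI wiring (a sketch; the real server entrypoints construct the
    # parser elsewhere, so the description and argv below are illustrative):
    #
    #   parser = FlexibleArgumentParser(description="Aphrodite engine args")
    #   parser = EngineArgs.add_cli_args(parser)
    #   args = parser.parse_args(["--model", "EleutherAI/pythia-70m-deduped",
    #                             "-tp", "2"])
    #   engine_args = EngineArgs.from_cli_args(args)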

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        # bitsandbytes quantization needs a specific model loader,
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")
        if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")
        multimodal_config = MultiModalConfig()
        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            load_in_4bit=self.load_in_4bit,
            load_in_8bit=self.load_in_8bit,
            load_in_smooth=self.load_in_smooth,
            deepspeed_fp_bits=self.deepspeed_fp_bits,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=multimodal_config,
        )
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                tokenizer_pool_size=self.tokenizer_pool_size,
                tokenizer_pool_type=self.tokenizer_pool_type,
                tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_draft_tensor_parallel_size=self.
            speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )
        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
                self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
        )
        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
            if self.enable_prompt_adapter else None
        control_vector_config = ControlVectorConfig(
            max_control_vectors=self.max_control_vectors,
            normalize=self.normalize_control_vector,
        ) if self.enable_control_vector else None
        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            multimodal_config=multimodal_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config,
                            prompt_adapter_config=prompt_adapter_config,
                            control_vector_config=control_vector_config)
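

# A sketch of consuming the result of create_engine_config (assuming
# EngineConfig keeps the keyword arguments above as attributes of the same
# names; the model name is illustrative):
#
#   engine_config = EngineArgs(
#       model="EleutherAI/pythia-70m-deduped").create_engine_config()
#   engine_config.model_config, engine_config.cache_config, ...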


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous Aphrodite engine."""
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: int = 0
    uvloop: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        parser.add_argument('--max-log-len',
                            type=int,
                            default=0,
                            help='Max number of prompt characters or prompt '
                            'ID numbers being printed in log.'
                            '\n\nDefault: Unlimited')
        parser.add_argument(
            "--uvloop",
            action="store_true",
            help="Use the Uvloop asyncio event loop to possibly increase "
            "performance")
        return parser
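

# A sketch of wiring the async variant (values illustrative; from_cli_args is
# inherited from EngineArgs, and the flags used below are defined above):
#
#   parser = AsyncEngineArgs.add_cli_args(FlexibleArgumentParser())
#   args = parser.parse_args(["--model", "EleutherAI/pythia-70m-deduped",
#                             "--disable-log-requests"])
#   async_engine_args = AsyncEngineArgs.from_cli_args(args)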