args_tools.py

import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

from loguru import logger

from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
                                     EngineConfig, LoadConfig, LoRAConfig,
                                     ModelConfig, MultiModalConfig,
                                     ParallelConfig, PromptAdapterConfig,
                                     SchedulerConfig, SpeculativeConfig,
                                     TokenizerPoolConfig)
from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
from aphrodite.executor.executor_base import ExecutorBase
from aphrodite.quantization import QUANTIZATION_METHODS
from aphrodite.transformers_utils.utils import check_gguf_file

if TYPE_CHECKING:
    from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (  # noqa: E501
        BaseTokenizerGroup)


@dataclass
class EngineArgs:
    """Arguments for Aphrodite engine."""
    # Model Options
    model: str
    seed: int = 0
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    max_model_len: Optional[int] = None
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    model_loader_extra_config: Optional[dict] = None
    enforce_eager: Optional[bool] = True
    skip_tokenizer_init: bool = False
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    max_logprobs: int = 10  # OpenAI default is 5, setting to 10 because ST
    # Device Options
    device: str = "auto"
    # Load Options
    load_format: str = "auto"
    dtype: str = "auto"
    ignore_patterns: Optional[Union[str, List[str]]] = None
    # Parallel Options
    worker_use_ray: Optional[bool] = False
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    ray_workers_use_nsight: bool = False
    disable_custom_all_reduce: bool = False
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[
        Union[str, Type[ExecutorBase]]] = None
    max_parallel_loading_workers: Optional[int] = None
    # Quantization Options
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    preemption_mode: Optional[str] = None
    deepspeed_fp_bits: Optional[int] = None
    # Cache Options
    kv_cache_dtype: str = "auto"
    block_size: int = 16
    enable_prefix_caching: Optional[bool] = False
    num_gpu_blocks_override: Optional[int] = None
    disable_sliding_window: bool = False
    gpu_memory_utilization: float = 0.90
    swap_space: int = 4  # GiB
    cpu_offload_gb: int = 0  # GiB
    # Scheduler Options
    use_v2_block_manager: bool = False
    scheduler_delay_factor: float = 0.0
    # None means "not set": create_engine_config may auto-enable chunked
    # prefill for long-context models and otherwise falls back to False.
    enable_chunked_prefill: Optional[bool] = None
    guided_decoding_backend: str = 'outlines'
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    # Speculative Decoding Options
    num_lookahead_slots: int = 0
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None
    # Adapter Options
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    lora_dtype: str = "auto"
    max_cpu_loras: Optional[int] = None
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    fully_sharded_loras: bool = False
    qlora_adapter_name_or_path: Optional[str] = None
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    # Log Options
    disable_log_stats: bool = False

    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model
        if is_cpu():
            self.distributed_executor_backend = None

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for the Aphrodite engine."""
        # Model Options
        parser.add_argument(
            "--model",
            type=str,
            default="EleutherAI/pythia-70m-deduped",
            help="Category: Model Options\n"
            "name or path of the huggingface model to use",
        )
        parser.add_argument("--seed",
                            type=int,
                            default=EngineArgs.seed,
                            help="Category: Model Options\n"
                            "random seed")
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="Category: API Options\n"
            "The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
  148. "same as the `--model` argument. Noted that this name(s)"
  149. "will also be used in `model_name` tag content of "
  150. "prometheus metrics, if multiple names provided, metrics"
  151. "tag will take the first one.")
        parser.add_argument(
            "--tokenizer",
            type=str,
            default=EngineArgs.tokenizer,
            help="Category: Model Options\n"
            "name or path of the huggingface tokenizer to use",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--code-revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific revision to use for the model code on "
            "Hugging Face Hub. It can be a branch name, a tag name, or a "
            "commit id. If unspecified, will use the default version.",
        )
        parser.add_argument(
            "--tokenizer-revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific tokenizer version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=["auto", "slow"],
            help='Category: Model Options\n'
            'tokenizer mode. "auto" will use the fast '
            'tokenizer if available, and "slow" will '
            "always use the slow tokenizer.",
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Category: Model Options\n"
            "trust remote code from huggingface",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=EngineArgs.download_dir,
            help="Category: Model Options\n"
            "directory to download and load the weights, "
            "default to the default cache dir of "
            "huggingface",
        )
        parser.add_argument(
            "--max-model-len",
            type=int,
            default=EngineArgs.max_model_len,
            help="Category: Model Options\n"
            "model context length. If unspecified, "
            "will be automatically derived from the model.",
        )
        parser.add_argument("--max-context-len-to-capture",
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help="Category: Model Options\n"
                            "Maximum context length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode. "
                            "(DEPRECATED. Use --max-seq_len-to-capture instead"
                            ")")
        parser.add_argument("--max-seq_len-to-capture",
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help="Category: Model Options\n"
                            "Maximum sequence length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode.")
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='Category: Model Options\n'
                            'RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='Category: Model Options\n'
                            'RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument("--model-loader-extra-config",
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help="Category: Model Options\n"
                            "Extra config for model loader. "
                            "This will be passed to the model loader "
                            "corresponding to the chosen load_format. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary.")
        parser.add_argument(
            "--enforce-eager",
            action=StoreBoolean,
            default=EngineArgs.enforce_eager,
            nargs="?",
            const="True",
            help="Category: Model Options\n"
            "Always use eager-mode PyTorch. If False, "
            "will use eager mode and CUDA graph in hybrid "
            "for maximal performance and flexibility.",
        )
        parser.add_argument("--skip-tokenizer-init",
                            action="store_true",
                            help="Category: Model Options\n"
                            "Skip initialization of tokenizer and detokenizer")
        parser.add_argument("--tokenizer-pool-size",
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help="Category: Model Options\n"
                            "Size of tokenizer pool to use for "
                            "asynchronous tokenization. If 0, will "
                            "use synchronous tokenization.")
        parser.add_argument("--tokenizer-pool-type",
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help="Category: Model Options\n"
                            "The type of tokenizer pool to use for "
                            "asynchronous tokenization. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument("--tokenizer-pool-extra-config",
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help="Category: Model Options\n"
                            "Extra config for tokenizer pool. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument(
            "--max-logprobs",
            type=int,
            default=EngineArgs.max_logprobs,
            help="Category: Model Options\n"
            "maximum number of log probabilities to "
            "return.",
        )
        # Device Options
        parser.add_argument(
            "--device",
            type=str,
            default=EngineArgs.device,
            choices=[
                "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu"
            ],
            help=("Category: Model Options\n"
                  "Device to use for model execution."),
        )
        # Load Options
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto',
                'pt',
                'safetensors',
                'npcache',
                'dummy',
                'tensorizer',
                'sharded_state',
                'bitsandbytes',
            ],
            help='Category: Model Options\n'
            'The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize Aphrodite Model script in the '
            'Examples section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Category: Model Options\n'
            'Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="Category: Model Options\n"
            "The pattern(s) to ignore when loading the model. "
            "Defaults to 'original/**/*' to avoid repeated loading of llama's "
            "checkpoints.")
        # Parallel Options
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Category: Parallel Options\n'
            'Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument(
            "--tensor-parallel-size",
            "-tp",
            type=int,
            default=EngineArgs.tensor_parallel_size,
            help="Category: Parallel Options\n"
            "number of tensor parallel replicas, i.e. the number of GPUs "
            "to use.")
        parser.add_argument(
            "--pipeline-parallel-size",
            "-pp",
            type=int,
            default=EngineArgs.pipeline_parallel_size,
            help="Category: Parallel Options\n"
            "number of pipeline stages. Currently not supported.")
        parser.add_argument(
            "--ray-workers-use-nsight",
            action="store_true",
            help="Category: Parallel Options\n"
            "If specified, use nsight to profile ray workers",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=EngineArgs.disable_custom_all_reduce,
            help="Category: Model Options\n"
            "See ParallelConfig",
        )
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Category: Parallel Options\n'
            'Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            "--max-parallel-loading-workers",
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help="Category: Parallel Options\n"
            "load model sequentially in multiple batches, "
            "to avoid RAM OOM when using tensor "
            "parallel and large models",
        )
        # Quantization Options
        parser.add_argument(
            "--quantization",
            "-q",
            type=str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.quantization,
            help="Category: Quantization Options\n"
            "Method used to quantize the weights. If "
            "None, we first check the `quantization_config` "
            "attribute in the model config file. If that is "
            "None, we assume the model weights are not "
            "quantized and use `dtype` to determine the data "
            "type of the weights.",
        )
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Category: Quantization Options\n'
            'Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied when the '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on CUDA versions '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=None,
            help='Category: Scheduler Options\n'
            'If \'recompute\', the engine performs preemption by '
            'recomputing; if \'swap\', the engine performs preemption by '
            'block swapping.')
        parser.add_argument(
            "--deepspeed-fp-bits",
            type=int,
            default=None,
            help="Category: Quantization Options\n"
            "Number of floating bits to use for the deepspeed "
            "quantization. Supported bits are: 4, 6, 8, 12.")
        # Cache Options
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Category: Cache Options\n'
            'Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            "--block-size",
            type=int,
            default=EngineArgs.block_size,
            choices=[8, 16, 32],
            help="Category: Cache Options\n"
            "token block size",
        )
        parser.add_argument(
            "--enable-prefix-caching",
            "--context-shift",
            action="store_true",
            help="Category: Cache Options\n"
            "Enable automatic prefix caching.",
        )
        parser.add_argument(
            "--num-gpu-blocks-override",
            type=int,
            default=None,
            help="Category: Cache Options\n"
            "If specified, ignore GPU profiling result and use this "
            "number of GPU blocks. Used for testing preemption.")
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Category: KV Cache Options\n'
                            'Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument(
            "--gpu-memory-utilization",
            "-gmu",
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="Category: Cache Options\n"
            "The fraction of GPU memory to be used for "
            "the model executor, which can range from 0 to 1. "
            "If unspecified, will use the default value of 0.9.",
        )
        parser.add_argument(
            "--swap-space",
            type=int,
            default=EngineArgs.swap_space,
            help="Category: Cache Options\n"
            "CPU swap space size (GiB) per GPU",
        )
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='Category: Cache Options\n'
            'The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weights, '
            'which requires at least 26 GB of GPU memory. Note that this '
            'requires a fast CPU-GPU interconnect, as part of the model is '
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
        # Scheduler Options
        parser.add_argument("--use-v2-block-manager",
                            action="store_true",
                            help="Category: Scheduler Options\n"
                            "Use the v2 block manager.")
        parser.add_argument(
            "--scheduler-delay-factor",
            "-sdf",
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help="Category: Scheduler Options\n"
            "Apply a delay (of delay factor multiplied by previous "
            "prompt latency) before scheduling next prompt.")
        parser.add_argument(
            "--enable-chunked-prefill",
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
            help="Category: Scheduler Options\n"
            "If True, the prefill requests can be chunked based on the "
            "max_num_batched_tokens.")
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Category: Scheduler Options\n'
            'Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
        parser.add_argument(
            "--max-num-batched-tokens",
            type=int,
            default=EngineArgs.max_num_batched_tokens,
            help="Category: KV Cache Options\n"
            "maximum number of batched tokens per "
            "iteration",
        )
        parser.add_argument(
            "--max-num-seqs",
            type=int,
            default=EngineArgs.max_num_seqs,
            help="Category: API Options\n"
            "maximum number of sequences per iteration",
        )
        # Speculative Decoding Options
        parser.add_argument("--num-lookahead-slots",
                            type=int,
                            default=EngineArgs.num_lookahead_slots,
                            help="Category: Speculative Decoding Options\n"
                            "Experimental scheduling config necessary for "
                            "speculative decoding. This will be replaced by "
                            "speculative decoding config in the future; it is "
                            "present for testing purposes until then.")
        parser.add_argument(
            "--speculative-model",
            type=str,
            default=EngineArgs.speculative_model,
            help="Category: Speculative Decoding Options\n"
            "The name of the draft model to be used in speculative decoding.")
        parser.add_argument("--num-speculative-tokens",
                            type=int,
                            default=EngineArgs.num_speculative_tokens,
                            help="Category: Speculative Decoding Options\n"
                            "The number of speculative tokens to sample from "
                            "the draft model in speculative decoding")
        parser.add_argument(
            "--speculative-max-model-len",
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help="Category: Speculative Decoding Options\n"
            "The maximum sequence length supported by the "
            "draft model. Sequences over this length will skip "
            "speculation.")
        parser.add_argument(
            "--ngram-prompt-lookup-max",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help="Category: Speculative Decoding Options\n"
            "Max size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--ngram-prompt-lookup-min",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help="Category: Speculative Decoding Options\n"
            "Min size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--speculative-draft-tensor-parallel-size",
            "-spec-draft-tp",
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help="Category: Speculative Decoding Options\n"
            "Number of tensor parallel replicas for "
            "the draft model in speculative decoding.")
        parser.add_argument(
            "--speculative-disable-by-batch-size",
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help="Category: Speculative Decoding Options\n"
            "Disable speculative decoding for new incoming requests "
            "if the number of enqueued requests is larger than this value.")
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Category: Speculative Decoding Options\n'
            'Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Category: Speculative Decoding Options\n'
            'Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='Category: Speculative Decoding Options\n'
            'A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
            action=StoreBoolean,
            default=EngineArgs.disable_logprobs_during_spec_decoding,
            nargs="?",
            const="True",
            help='Category: Speculative Decoding Options\n'
            'If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')
        # Adapter Options
        parser.add_argument(
            "--enable-lora",
            action="store_true",
            help="Category: Adapter Options\n"
            "If True, enable handling of LoRA adapters.",
        )
        parser.add_argument(
            "--max-loras",
            type=int,
            default=EngineArgs.max_loras,
            help="Category: Adapter Options\n"
            "Max number of LoRAs in a single batch.",
        )
        parser.add_argument(
            "--max-lora-rank",
            type=int,
            default=EngineArgs.max_lora_rank,
            help="Category: Adapter Options\n"
            "Max LoRA rank.",
        )
        parser.add_argument(
            "--lora-extra-vocab-size",
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=("Category: Adapter Options\n"
                  "Maximum size of extra vocabulary that can be "
                  "present in a LoRA adapter (added to the base "
                  "model vocabulary)."),
        )
        parser.add_argument(
            "--lora-dtype",
            type=str,
            default=EngineArgs.lora_dtype,
            choices=["auto", "float16", "bfloat16", "float32"],
            help=("Category: Adapter Options\n"
                  "Data type for LoRA. If auto, will default to "
                  "base model dtype."),
        )
        parser.add_argument(
            "--max-cpu-loras",
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=("Category: Adapter Options\n"
                  "Maximum number of LoRAs to store in CPU memory. "
                  "Must be >= max_num_seqs. "
                  "Defaults to max_num_seqs."),
        )
        parser.add_argument(
            "--long-lora-scaling-factors",
            type=str,
            default=EngineArgs.long_lora_scaling_factors,
            help=("Category: Adapter Options\n"
                  "Specify multiple scaling factors (which can "
                  "be different from base model scaling factor "
                  "- see eg. Long LoRA) to allow for multiple "
                  "LoRA adapters trained with those scaling "
                  "factors to be used at the same time. If not "
                  "specified, only adapters trained with the "
                  "base model scaling factor are allowed."))
        parser.add_argument(
            "--fully-sharded-loras",
            action='store_true',
            help=("Category: Adapter Options\n"
                  "By default, only half of the LoRA computation is sharded "
                  "with tensor parallelism. Enabling this will use the fully "
                  "sharded layers. At high sequence length, max rank or "
                  "tensor parallel size, this is likely faster."))
        parser.add_argument("--qlora-adapter-name-or-path",
                            type=str,
                            default=None,
                            help="Category: Adapter Options\n"
                            "Name or path of the LoRA adapter to use.")
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='Category: Adapter Options\n'
                            'If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Category: Adapter Options\n'
                            'Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Category: Adapter Options\n'
                            'Max number of PromptAdapter tokens')
        # Log Options
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Category: Log Options\n"
            "disable logging statistics",
        )
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader,
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")
        if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        multimodal_config = MultiModalConfig()
        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            deepspeed_fp_bits=self.deepspeed_fp_bits,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            multimodal_config=multimodal_config,
        )
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                tokenizer_pool_size=self.tokenizer_pool_size,
                tokenizer_pool_type=self.tokenizer_pool_type,
                tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
            if use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                has_seqlen_agnostic_layers = (
                    model_config.contains_seqlen_agnostic_layers(
                        parallel_config))
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and not self.enable_prefix_caching
                        and not has_seqlen_agnostic_layers):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Currently, chunked "
                        "prefill might not work with some features or "
                        "models. If you encounter any issues, please "
                        "disable chunked prefill by setting "
                        "--enable-chunked-prefill=False.")
        if self.enable_chunked_prefill is None:
            self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                f"The model has a long context length ({max_model_len}). "
                "This may cause OOM errors during the initial memory "
                "profiling phase, or result in low performance due to small "
                "KV cache space. Consider setting --max-model-len to a "
                "smaller value.")

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_draft_tensor_parallel_size=self.
            speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
                self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns)

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
            if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            multimodal_config=multimodal_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config,
                            prompt_adapter_config=prompt_adapter_config)
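

# Usage sketch: EngineArgs can also be constructed programmatically instead of
# via the CLI, then turned into per-component configs with
# `create_engine_config()`. The model name below is just the same default used
# by `--model`; the other values are illustrative, not recommendations.
#
#     engine_args = EngineArgs(model="EleutherAI/pythia-70m-deduped",
#                              tensor_parallel_size=1,
#                              gpu_memory_utilization=0.90)
#     engine_config = engine_args.create_engine_config()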


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous Aphrodite engine."""
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    uvloop: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        parser.add_argument(
            "--uvloop",
            action="store_true",
            help="Use the Uvloop asyncio event loop to possibly increase "
            "performance")
        return parser


class StoreBoolean(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        if values.lower() == "true":
            setattr(namespace, self.dest, True)
        elif values.lower() == "false":
            setattr(namespace, self.dest, False)
        else:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")