import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
                    Union)

from loguru import logger

from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
                                     EngineConfig, LoadConfig, LoRAConfig,
                                     ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig,
                                     SpeculativeConfig, TokenizerPoolConfig)
from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
from aphrodite.executor.executor_base import ExecutorBase
from aphrodite.quantization import QUANTIZATION_METHODS
from aphrodite.transformers_utils.utils import check_gguf_file

if TYPE_CHECKING:
    from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    if len(val) == 0:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        try:
            key, value = item.split("=")
        except ValueError as exc:
            # A malformed item (not exactly KEY=VALUE) fails the unpacking
            # above with a ValueError.
            msg = "Each item should be in the form KEY=VALUE"
            raise ValueError(msg) from exc

        try:
            out_dict[key] = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise ValueError(msg) from exc

    return out_dict

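# Illustrative usage of nullable_kvs (a sketch, not executed by this module):
# it is the value parser behind --limit-mm-per-prompt.
#
#     assert nullable_kvs("image=16,video=2") == {"image": 16, "video": 2}
#     assert nullable_kvs("") is None
#     nullable_kvs("image=sixteen")  # raises ValueError
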
@dataclass
class EngineArgs:
    """Arguments for Aphrodite engine."""
    # Model Options
    model: str
    seed: int = 0
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    tokenizer_mode: str = "auto"
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    max_model_len: Optional[int] = None
    max_context_len_to_capture: Optional[int] = None
    max_seq_len_to_capture: int = 8192
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None
    model_loader_extra_config: Optional[dict] = None
    enforce_eager: Optional[bool] = None
    skip_tokenizer_init: bool = False
    tokenizer_pool_size: int = 0
    # Note: Specifying a tokenizer pool by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
    max_logprobs: int = 10  # OpenAI default is 5, setting to 10 because ST
    # Device Options
    device: str = "auto"
    # Load Options
    load_format: str = "auto"
    dtype: str = "auto"
    ignore_patterns: Optional[Union[str, List[str]]] = None
    # Parallel Options
    worker_use_ray: Optional[bool] = False
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    ray_workers_use_nsight: bool = False
    disable_custom_all_reduce: bool = False
    # Note: Specifying a custom executor backend by passing a class
    # is intended for expert use only. The API may change without
    # notice.
    distributed_executor_backend: Optional[Union[str,
                                                 Type[ExecutorBase]]] = None
    max_parallel_loading_workers: Optional[int] = None
    # Quantization Options
    quantization: Optional[str] = None
    quantization_param_path: Optional[str] = None
    preemption_mode: Optional[str] = None
    deepspeed_fp_bits: Optional[int] = None
    # Cache Options
    kv_cache_dtype: str = "auto"
    block_size: int = 16
    enable_prefix_caching: Optional[bool] = False
    num_gpu_blocks_override: Optional[int] = None
    disable_sliding_window: bool = False
    gpu_memory_utilization: float = 0.90
    swap_space: float = 4  # GiB
    cpu_offload_gb: float = 0  # GiB
    # Scheduler Options
    use_v2_block_manager: bool = False
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False
    guided_decoding_backend: str = 'outlines'
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    num_scheduler_steps: int = 1
    # Speculative Decoding Options
    num_lookahead_slots: int = 0
    speculative_model: Optional[str] = None
    speculative_model_quantization: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    spec_decoding_acceptance_method: str = 'rejection_sampler'
    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
    disable_logprobs_during_spec_decoding: Optional[bool] = None
    # Adapter Options
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    lora_dtype: str = "auto"
    max_cpu_loras: Optional[int] = None
    long_lora_scaling_factors: Optional[Tuple[float]] = None
    fully_sharded_loras: bool = False
    qlora_adapter_name_or_path: Optional[str] = None
    enable_prompt_adapter: bool = False
    max_prompt_adapters: int = 1
    max_prompt_adapter_token: int = 0
    # Log Options
    disable_log_stats: bool = False

    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model
        if is_cpu():
            self.distributed_executor_backend = None

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for the Aphrodite engine."""

        # Model Options
        parser.add_argument(
            "--model",
            type=str,
            default="EleutherAI/pythia-70m-deduped",
            help="Category: Model Options\n"
            "name or path of the huggingface model to use",
        )
        parser.add_argument("--seed",
                            type=int,
                            default=EngineArgs.seed,
                            help="Category: Model Options\n"
                            "random seed")
        parser.add_argument(
            "--served-model-name",
            nargs="+",
            type=str,
            default=None,
            help="Category: API Options\n"
            "The model name(s) used in the API. If multiple "
            "names are provided, the server will respond to any "
            "of the provided names. The model name in the model "
            "field of a response will be the first name in this "
            "list. If not specified, the model name will be the "
            "same as the `--model` argument. Note that this name(s) "
            "will also be used in the `model_name` tag content of "
            "prometheus metrics; if multiple names are provided, the "
            "metrics tag will take the first one.")
        parser.add_argument(
            "--tokenizer",
            type=str,
            default=EngineArgs.tokenizer,
            help="Category: Model Options\n"
            "name or path of the huggingface tokenizer to use",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific model version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--code-revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific revision to use for the model code on "
            "Hugging Face Hub. It can be a branch name, a tag name, or a "
            "commit id. If unspecified, will use the default version.",
        )
        parser.add_argument(
            "--tokenizer-revision",
            type=str,
            default=None,
            help="Category: Model Options\n"
            "the specific tokenizer version to use. It can be a branch "
            "name, a tag name, or a commit id. If unspecified, will use "
            "the default version.",
        )
        parser.add_argument(
            "--tokenizer-mode",
            type=str,
            default=EngineArgs.tokenizer_mode,
            choices=['auto', 'slow', 'mistral'],
            help='The tokenizer mode.\n\n* "auto" will use the '
            'fast tokenizer if available.\n* "slow" will '
            'always use the slow tokenizer.\n* '
            '"mistral" will always use the `mistral_common` tokenizer.')
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Category: Model Options\n"
            "trust remote code from huggingface",
        )
        parser.add_argument(
            "--download-dir",
            type=str,
            default=EngineArgs.download_dir,
            help="Category: Model Options\n"
            "directory to download and load the weights, "
            "default to the default cache dir of "
            "huggingface",
        )
        parser.add_argument(
            "--max-model-len",
            type=int,
            default=EngineArgs.max_model_len,
            help="Category: Model Options\n"
            "model context length. If unspecified, "
            "will be automatically derived from the model.",
        )
        parser.add_argument("--max-context-len-to-capture",
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help="Category: Model Options\n"
                            "Maximum context length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode. "
                            "(DEPRECATED. Use --max-seq-len-to-capture "
                            "instead)")
        parser.add_argument("--max-seq-len-to-capture",
                            type=int,
                            default=EngineArgs.max_seq_len_to_capture,
                            help="Category: Model Options\n"
                            "Maximum sequence length covered by CUDA "
                            "graphs. When a sequence has context length "
                            "larger than this, we fall back to eager mode.")
        parser.add_argument('--rope-scaling',
                            default=None,
                            type=json.loads,
                            help='Category: Model Options\n'
                            'RoPE scaling configuration in JSON format. '
                            'For example, {"type":"dynamic","factor":2.0}')
        parser.add_argument('--rope-theta',
                            default=None,
                            type=float,
                            help='Category: Model Options\n'
                            'RoPE theta. Use with `rope_scaling`. In '
                            'some cases, changing the RoPE theta improves the '
                            'performance of the scaled model.')
        parser.add_argument("--model-loader-extra-config",
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
                            help="Category: Model Options\n"
                            "Extra config for model loader. "
                            "This will be passed to the model loader "
                            "corresponding to the chosen load_format. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary.")
        parser.add_argument(
            "--enforce-eager",
            action=StoreBoolean,
            default=EngineArgs.enforce_eager,
            nargs="?",
            const="True",
            help="Category: Model Options\n"
            "Always use eager-mode PyTorch. If False, "
            "will use eager mode and CUDA graph in hybrid "
            "for maximal performance and flexibility.",
        )
        parser.add_argument("--skip-tokenizer-init",
                            action="store_true",
                            help="Category: Model Options\n"
                            "Skip initialization of tokenizer and detokenizer")
        parser.add_argument("--tokenizer-pool-size",
                            type=int,
                            default=EngineArgs.tokenizer_pool_size,
                            help="Category: Model Options\n"
                            "Size of tokenizer pool to use for "
                            "asynchronous tokenization. If 0, will "
                            "use synchronous tokenization.")
        parser.add_argument("--tokenizer-pool-type",
                            type=str,
                            default=EngineArgs.tokenizer_pool_type,
                            help="Category: Model Options\n"
                            "The type of tokenizer pool to use for "
                            "asynchronous tokenization. Ignored if "
                            "tokenizer_pool_size is 0.")
        parser.add_argument("--tokenizer-pool-extra-config",
                            type=str,
                            default=EngineArgs.tokenizer_pool_extra_config,
                            help="Category: Model Options\n"
                            "Extra config for tokenizer pool. "
                            "This should be a JSON string that will be "
                            "parsed into a dictionary. Ignored if "
                            "tokenizer_pool_size is 0.")
        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
        parser.add_argument(
            "--max-logprobs",
            type=int,
            default=EngineArgs.max_logprobs,
            help="Category: Model Options\n"
            "maximum number of log probabilities to "
            "return.",
        )
        # Device Options
        parser.add_argument(
            "--device",
            type=str,
            default=EngineArgs.device,
            choices=[
                "auto", "cuda", "neuron", "cpu", "openvino", "tpu", "xpu"
            ],
            help=("Category: Model Options\n"
                  "Device to use for model execution."),
        )
        # Load Options
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=[
                'auto',
                'pt',
                'safetensors',
                'npcache',
                'dummy',
                'tensorizer',
                'sharded_state',
                'bitsandbytes',
            ],
            help='Category: Model Options\n'
            'The format of the model weights to load.\n\n'
            '* "auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available.\n'
            '* "pt" will load the weights in the pytorch bin format.\n'
            '* "safetensors" will load the weights in the safetensors format.\n'
            '* "npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading.\n'
            '* "dummy" will initialize the weights with random values, '
            'which is mainly for profiling.\n'
            '* "tensorizer" will load the weights using tensorizer from '
            'CoreWeave. See the Tensorize Aphrodite Model script in the '
            'Examples section for more information.\n'
            '* "bitsandbytes" will load the weights using bitsandbytes '
            'quantization.\n')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='Category: Model Options\n'
            'Data type for model weights and activations.\n\n'
            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
            'BF16 precision for BF16 models.\n'
            '* "half" for FP16. Recommended for AWQ quantization.\n'
            '* "float16" is the same as "half".\n'
            '* "bfloat16" for a balance between precision and range.\n'
            '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
        parser.add_argument(
            '--ignore-patterns',
            action="append",
            type=str,
            default=[],
            help="Category: Model Options\n"
            "The pattern(s) to ignore when loading the model. "
            "Defaults to 'original/**/*' to avoid repeated loading of llama's "
            "checkpoints.")
        # Parallel Options
        parser.add_argument(
            '--worker-use-ray',
            action='store_true',
            help='Category: Parallel Options\n'
            'Deprecated, use --distributed-executor-backend=ray.')
        parser.add_argument(
            "--tensor-parallel-size",
            "-tp",
            type=int,
            default=EngineArgs.tensor_parallel_size,
            help="Category: Parallel Options\n"
            "number of tensor parallel replicas, i.e. the number of GPUs "
            "to use.")
        parser.add_argument(
            "--pipeline-parallel-size",
            "-pp",
            type=int,
            default=EngineArgs.pipeline_parallel_size,
            help="Category: Parallel Options\n"
            "number of pipeline stages. Currently not supported.")
        parser.add_argument(
            "--ray-workers-use-nsight",
            action="store_true",
            help="Category: Parallel Options\n"
            "If specified, use nsight to profile ray workers",
        )
        parser.add_argument(
            "--disable-custom-all-reduce",
            action="store_true",
            default=EngineArgs.disable_custom_all_reduce,
            help="Category: Model Options\n"
            "See ParallelConfig",
        )
        parser.add_argument(
            '--distributed-executor-backend',
            choices=['ray', 'mp'],
            default=EngineArgs.distributed_executor_backend,
            help='Category: Parallel Options\n'
            'Backend to use for distributed serving. When more than 1 GPU '
            'is used, will be automatically set to "ray" if installed '
            'or "mp" (multiprocessing) otherwise.')
        parser.add_argument(
            "--max-parallel-loading-workers",
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help="Category: Parallel Options\n"
            "load model sequentially in multiple batches, "
            "to avoid RAM OOM when using tensor "
            "parallel and large models",
        )
        # Quantization Options
        parser.add_argument(
            "--quantization",
            "-q",
            type=str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.quantization,
            help="Category: Quantization Options\n"
            "Method used to quantize the weights. If "
            "None, we first check the `quantization_config` "
            "attribute in the model config file. If that is "
            "None, we assume the model weights are not "
            "quantized and use `dtype` to determine the data "
            "type of the weights.",
        )
        parser.add_argument(
            '--quantization-param-path',
            type=str,
            default=None,
            help='Category: Quantization Options\n'
            'Path to the JSON file containing the KV cache '
            'scaling factors. This should generally be supplied, when '
            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
            'default to 1.0, which may cause accuracy issues. '
            'FP8_E5M2 (without scaling) is only supported on cuda versions '
            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
            'supported for common inference criteria.')
        parser.add_argument(
            '--preemption-mode',
            type=str,
            default=None,
            help='Category: Scheduler Options\n'
            'If \'recompute\', the engine performs preemption by '
            'recomputing; if \'swap\', the engine performs preemption by '
            'block swapping.')
        parser.add_argument("--deepspeed-fp-bits",
                            type=int,
                            default=None,
                            help="Category: Quantization Options\n"
                            "Number of floating bits to use for the deepspeed "
                            "quantization. Supported bits are: 4, 6, 8, 12.")
        # Cache Options
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
            default=EngineArgs.kv_cache_dtype,
            help='Category: Cache Options\n'
            'Data type for kv cache storage. If "auto", will use model '
            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
        parser.add_argument(
            "--block-size",
            type=int,
            default=EngineArgs.block_size,
            choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
            help="Category: Cache Options\n"
            "token block size",
        )
        parser.add_argument(
            "--enable-prefix-caching",
            "--context-shift",
            action="store_true",
            help="Category: Cache Options\n"
            "Enable automatic prefix caching.",
        )
        parser.add_argument(
            "--num-gpu-blocks-override",
            type=int,
            default=None,
            help="Category: Cache Options\n"
            "If specified, ignore GPU profiling result and use this "
            "number of GPU blocks. Used for testing preemption.")
        parser.add_argument('--disable-sliding-window',
                            action='store_true',
                            help='Category: KV Cache Options\n'
                            'Disables sliding window, '
                            'capping to sliding window size')
        parser.add_argument(
            "--gpu-memory-utilization",
            "-gmu",
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="Category: Cache Options\n"
            "The fraction of GPU memory to be used for "
            "the model executor, which can range from 0 to 1. "
            "If unspecified, will use the default value of 0.9.",
        )
        parser.add_argument(
            "--swap-space",
            type=float,
            default=EngineArgs.swap_space,
            help="Category: Cache Options\n"
            "CPU swap space size (GiB) per GPU",
        )
        parser.add_argument(
            '--cpu-offload-gb',
            type=float,
            default=0,
            help='Category: Cache Options\n'
            'The space in GiB to offload to CPU, per GPU. '
            'Default is 0, which means no offloading. Intuitively, '
            'this argument can be seen as a virtual way to increase '
            'the GPU memory size. For example, if you have one 24 GB '
            'GPU and set this to 10, virtually you can think of it as '
            'a 34 GB GPU. Then you can load a 13B model with BF16 weights, '
            'which requires at least 26 GB of GPU memory. Note that this '
            'requires a fast CPU-GPU interconnect, as part of the model is '
            'loaded from CPU memory to GPU memory on the fly in each '
            'model forward pass.')
        # Scheduler Options
        parser.add_argument("--use-v2-block-manager",
                            action="store_true",
                            help="Category: Scheduler Options\n"
                            "Use the v2 block manager.")
        parser.add_argument(
            "--scheduler-delay-factor",
            "-sdf",
            type=float,
            default=EngineArgs.scheduler_delay_factor,
            help="Category: Scheduler Options\n"
            "Apply a delay (of delay factor multiplied by previous "
            "prompt latency) before scheduling next prompt.")
        parser.add_argument(
            "--enable-chunked-prefill",
            action=StoreBoolean,
            default=EngineArgs.enable_chunked_prefill,
            nargs="?",
            const="True",
            help="Category: Scheduler Options\n"
            "If True, the prefill requests can be chunked based on the "
            "max_num_batched_tokens.")
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Category: Scheduler Options\n'
            'Which engine will be used for guided decoding'
            ' (JSON schema / regex etc) by default. Currently support '
            'https://github.com/outlines-dev/outlines and '
            'https://github.com/noamgat/lm-format-enforcer.'
            ' Can be overridden per request via guided_decoding_backend'
            ' parameter.')
        parser.add_argument(
            "--max-num-batched-tokens",
            type=int,
            default=EngineArgs.max_num_batched_tokens,
            help="Category: KV Cache Options\n"
            "maximum number of batched tokens per "
            "iteration",
        )
        parser.add_argument(
            "--max-num-seqs",
            type=int,
            default=EngineArgs.max_num_seqs,
            help="Category: API Options\n"
            "maximum number of sequences per iteration",
        )
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
        # Speculative Decoding Options
        parser.add_argument("--num-lookahead-slots",
                            type=int,
                            default=EngineArgs.num_lookahead_slots,
                            help="Category: Speculative Decoding Options\n"
                            "Experimental scheduling config necessary for "
                            "speculative decoding. This will be replaced by "
                            "speculative decoding config in the future; it is "
                            "present for testing purposes until then.")
        parser.add_argument(
            "--speculative-model",
            type=str,
            default=EngineArgs.speculative_model,
            help="Category: Speculative Decoding Options\n"
            "The name of the draft model to be used in speculative decoding.")
        # Quantization settings for speculative model.
        parser.add_argument(
            '--speculative-model-quantization',
            type=str,
            choices=[*QUANTIZATION_METHODS, None],
            default=EngineArgs.speculative_model_quantization,
            help='Method used to quantize the weights of the speculative '
            'model. If None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
            'None, we assume the model weights are not '
            'quantized and use `dtype` to determine the data '
            'type of the weights.')
        parser.add_argument("--num-speculative-tokens",
                            type=int,
                            default=EngineArgs.num_speculative_tokens,
                            help="Category: Speculative Decoding Options\n"
                            "The number of speculative tokens to sample from "
                            "the draft model in speculative decoding")
        parser.add_argument(
            "--speculative-max-model-len",
            type=int,
            default=EngineArgs.speculative_max_model_len,
            help="Category: Speculative Decoding Options\n"
            "The maximum sequence length supported by the "
            "draft model. Sequences over this length will skip "
            "speculation.")
        parser.add_argument(
            "--ngram-prompt-lookup-max",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help="Category: Speculative Decoding Options\n"
            "Max size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--ngram-prompt-lookup-min",
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help="Category: Speculative Decoding Options\n"
            "Min size of window for ngram prompt lookup in speculative "
            "decoding.")
        parser.add_argument(
            "--speculative-draft-tensor-parallel-size",
            "-spec-draft-tp",
            type=int,
            default=EngineArgs.speculative_draft_tensor_parallel_size,
            help="Category: Speculative Decoding Options\n"
            "Number of tensor parallel replicas for "
            "the draft model in speculative decoding.")
        parser.add_argument(
            "--speculative-disable-by-batch-size",
            type=int,
            default=EngineArgs.speculative_disable_by_batch_size,
            help="Category: Speculative Decoding Options\n"
            "Disable speculative decoding for new incoming requests "
            "if the number of enqueued requests is larger than this value.")
        parser.add_argument(
            '--spec-decoding-acceptance-method',
            type=str,
            default=EngineArgs.spec_decoding_acceptance_method,
            choices=['rejection_sampler', 'typical_acceptance_sampler'],
            help='Category: Speculative Decoding Options\n'
            'Specify the acceptance method to use during draft token '
            'verification in speculative decoding. Two types of acceptance '
            'routines are supported: '
            '1) RejectionSampler which does not allow changing the '
            'acceptance rate of draft tokens, '
            '2) TypicalAcceptanceSampler which is configurable, allowing for '
            'a higher acceptance rate at the cost of lower quality, '
            'and vice versa.')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-threshold',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
            help='Category: Speculative Decoding Options\n'
            'Set the lower bound threshold for the posterior '
            'probability of a token to be accepted. This threshold is '
            'used by the TypicalAcceptanceSampler to make sampling decisions '
            'during speculative decoding. Defaults to 0.09')
        parser.add_argument(
            '--typical-acceptance-sampler-posterior-alpha',
            type=float,
            default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
            help='Category: Speculative Decoding Options\n'
            'A scaling factor for the entropy-based threshold for token '
            'acceptance in the TypicalAcceptanceSampler. Typically defaults '
            'to sqrt of --typical-acceptance-sampler-posterior-threshold '
            'i.e. 0.3')
        parser.add_argument(
            '--disable-logprobs-during-spec-decoding',
            type=bool,
            default=EngineArgs.disable_logprobs_during_spec_decoding,
            help='Category: Speculative Decoding Options\n'
            'If set to True, token log probabilities are not returned '
            'during speculative decoding. If set to False, log probabilities '
            'are returned according to the settings in SamplingParams. If '
            'not specified, it defaults to True. Disabling log probabilities '
            'during speculative decoding reduces latency by skipping logprob '
            'calculation in proposal sampling, target sampling, and after '
            'accepted tokens are determined.')
        # Adapter Options
        parser.add_argument(
            "--enable-lora",
            action="store_true",
            help="Category: Adapter Options\n"
            "If True, enable handling of LoRA adapters.",
        )
        parser.add_argument(
            "--max-loras",
            type=int,
            default=EngineArgs.max_loras,
            help="Category: Adapter Options\n"
            "Max number of LoRAs in a single batch.",
        )
        parser.add_argument(
            "--max-lora-rank",
            type=int,
            default=EngineArgs.max_lora_rank,
            help="Category: Adapter Options\n"
            "Max LoRA rank.",
        )
        parser.add_argument(
            "--lora-extra-vocab-size",
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=("Category: Adapter Options\n"
                  "Maximum size of extra vocabulary that can be "
                  "present in a LoRA adapter (added to the base "
                  "model vocabulary)."),
        )
        parser.add_argument(
            "--lora-dtype",
            type=str,
            default=EngineArgs.lora_dtype,
            choices=["auto", "float16", "bfloat16", "float32"],
            help=("Category: Adapter Options\n"
                  "Data type for LoRA. If auto, will default to "
                  "base model dtype."),
        )
        parser.add_argument(
            "--max-cpu-loras",
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=("Category: Adapter Options\n"
                  "Maximum number of LoRAs to store in CPU memory. "
                  "Must be >= max_num_seqs. "
                  "Defaults to max_num_seqs."),
        )
        parser.add_argument(
            "--long-lora-scaling-factors",
            type=str,
            default=EngineArgs.long_lora_scaling_factors,
            help=("Category: Adapter Options\n"
                  "Specify multiple scaling factors (which can "
                  "be different from the base model scaling factor "
                  "- see e.g. Long LoRA) to allow for multiple "
                  "LoRA adapters trained with those scaling "
                  "factors to be used at the same time. If not "
                  "specified, only adapters trained with the "
                  "base model scaling factor are allowed."))
        parser.add_argument(
            "--fully-sharded-loras",
            action='store_true',
            help=("Category: Adapter Options\n"
                  "By default, only half of the LoRA computation is sharded "
                  "with tensor parallelism. Enabling this will use the fully "
                  "sharded layers. At high sequence length, max rank or "
                  "tensor parallel size, this is likely faster."))
        parser.add_argument("--qlora-adapter-name-or-path",
                            type=str,
                            default=None,
                            help="Category: Adapter Options\n"
                            "Name or path of the QLoRA adapter to use.")
        parser.add_argument('--enable-prompt-adapter',
                            action='store_true',
                            help='Category: Adapter Options\n'
                            'If True, enable handling of PromptAdapters.')
        parser.add_argument('--max-prompt-adapters',
                            type=int,
                            default=EngineArgs.max_prompt_adapters,
                            help='Category: Adapter Options\n'
                            'Max number of PromptAdapters in a batch.')
        parser.add_argument('--max-prompt-adapter-token',
                            type=int,
                            default=EngineArgs.max_prompt_adapter_token,
                            help='Category: Adapter Options\n'
                            'Max number of PromptAdapters tokens')
        # Log Options
        parser.add_argument(
            "--disable-log-stats",
            action="store_true",
            help="Category: Log Options\n"
            "disable logging statistics",
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_config(self) -> EngineConfig:
        # gguf file needs a specific model loader and doesn't use hf_repo
        if check_gguf_file(self.model):
            self.quantization = self.load_format = "gguf"

        # bitsandbytes quantization needs a specific model loader,
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")
        if (self.load_format == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.quantization != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes load format and QLoRA adapter only support "
                f"'bitsandbytes' quantization, but got {self.quantization}")

        assert self.cpu_offload_gb >= 0, (
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")
        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
            model=self.model,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            deepspeed_fp_bits=self.deepspeed_fp_bits,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
        )
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            is_attention_free=model_config.is_attention_free(),
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=model_config.get_sliding_window(),
            enable_prefix_caching=self.enable_prefix_caching,
            cpu_offload_gb=self.cpu_offload_gb,
        )
        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            worker_use_ray=self.worker_use_ray,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            tokenizer_pool_config=TokenizerPoolConfig.create_config(
                tokenizer_pool_size=self.tokenizer_pool_size,
                tokenizer_pool_type=self.tokenizer_pool_type,
                tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
            ),
            ray_workers_use_nsight=self.ray_workers_use_nsight,
            distributed_executor_backend=self.distributed_executor_backend)
        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.
            if use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                has_seqlen_agnostic_layers = (
                    model_config.contains_seqlen_agnostic_layers(
                        parallel_config))
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and not self.enable_prefix_caching
                        and not has_seqlen_agnostic_layers):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Currently, chunked prefill "
                        "might not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                f"The model has a long context length ({max_model_len}). "
                "This may cause OOM errors during the initial memory "
                "profiling phase, or result in low performance due to small "
                "KV cache space. Consider setting --max-model-len to a "
                "smaller value.")
        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
            target_dtype=self.dtype,
            speculative_model=self.speculative_model,
            speculative_model_quantization = \
                self.speculative_model_quantization,
            speculative_draft_tensor_parallel_size=self.
            speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            disable_log_stats=self.disable_log_stats,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
            draft_token_acceptance_method=\
                self.spec_decoding_acceptance_method,
            typical_acceptance_sampler_posterior_threshold=self.
            typical_acceptance_sampler_posterior_threshold,
            typical_acceptance_sampler_posterior_alpha=self.
            typical_acceptance_sampler_posterior_alpha,
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )
        if self.num_scheduler_steps > 1:
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill:
                raise ValueError("Chunked prefill is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            raise NotImplementedError("Multi-step is not yet supported.")

        # Make sure num_lookahead_slots is set to the higher value, depending
        # on whether we are using speculative decoding or multi-step.
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots
        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            is_attention_free=model_config.is_attention_free(),
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
        )
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None

        if self.qlora_adapter_name_or_path is not None and \
                self.qlora_adapter_name_or_path != "":
            if self.model_loader_extra_config is None:
                self.model_loader_extra_config = {}
            self.model_loader_extra_config[
                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

        load_config = LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns)

        prompt_adapter_config = PromptAdapterConfig(
            max_prompt_adapters=self.max_prompt_adapters,
            max_prompt_adapter_token=self.max_prompt_adapter_token) \
            if self.enable_prompt_adapter else None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        if (model_config.get_sliding_window() is not None
                and scheduler_config.chunked_prefill_enabled
                and not scheduler_config.use_v2_block_manager):
            raise ValueError(
                "Chunked prefill is not supported with sliding window. "
                "Set --disable-sliding-window to disable sliding window.")

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
                            scheduler_config=scheduler_config,
                            device_config=device_config,
                            lora_config=lora_config,
                            speculative_config=speculative_config,
                            load_config=load_config,
                            decoding_config=decoding_config,
                            prompt_adapter_config=prompt_adapter_config)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous Aphrodite engine."""
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    uvloop: bool = False

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser,
                     async_args_only: bool = False) -> FlexibleArgumentParser:
        if not async_args_only:
            parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='Use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        parser.add_argument(
            "--uvloop",
            action="store_true",
            help="Use the Uvloop asyncio event loop to possibly increase "
            "performance")
        return parser


class StoreBoolean(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        if values.lower() == "true":
            setattr(namespace, self.dest, True)
        elif values.lower() == "false":
            setattr(namespace, self.dest, False)
        else:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")
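
# Illustrative usage sketch (not part of this module): how these pieces are
# typically wired together; downstream entrypoints may differ in details.
#
#     parser = FlexibleArgumentParser(description="Aphrodite engine args")
#     parser = AsyncEngineArgs.add_cli_args(parser)
#     args = parser.parse_args(["--model", "EleutherAI/pythia-70m-deduped",
#                               "--enforce-eager", "true"])
#     engine_args = AsyncEngineArgs.from_cli_args(args)
#     engine_config = engine_args.create_engine_config()
#
# StoreBoolean lets flags such as --enforce-eager accept an explicit
# "true"/"false" value, or be passed bare (nargs="?" with const="True").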