- import argparse
- import dataclasses
- import json
- from dataclasses import dataclass
- from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
- Type, Union)
- from loguru import logger
- import aphrodite.common.envs as envs
- from aphrodite.common.config import (CacheConfig, ConfigFormat, DecodingConfig,
- DeviceConfig, EngineConfig, LoadConfig,
- LoadFormat, LoRAConfig, ModelConfig,
- ParallelConfig, PromptAdapterConfig,
- SchedulerConfig, SpeculativeConfig,
- TokenizerPoolConfig)
- from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
- from aphrodite.executor.executor_base import ExecutorBase
- from aphrodite.quantization import QUANTIZATION_METHODS
- from aphrodite.transformers_utils.utils import check_gguf_file
- from aphrodite.triton_utils import HAS_TRITON
- if TYPE_CHECKING:
- from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup
- APHRODITE_USE_RAY_SPMD_WORKER = envs.APHRODITE_USE_RAY_SPMD_WORKER
- DEVICE_OPTIONS = [
- "auto",
- "cuda",
- "neuron",
- "cpu",
- "openvino",
- "tpu",
- "xpu",
- ]
- def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
- if len(val) == 0:
- return None
- out_dict: Dict[str, int] = {}
- for item in val.split(","):
- try:
- key, value = item.split("=")
- except ValueError as exc:
- msg = "Each item should be in the form KEY=VALUE"
- raise ValueError(msg) from exc
- try:
- out_dict[key] = int(value)
- except ValueError as exc:
- msg = f"Failed to parse value of item {key}={value}"
- raise ValueError(msg) from exc
- return out_dict
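- # Examples (illustrative): nullable_kvs("image=16,video=2") returns
- # {"image": 16, "video": 2}; nullable_kvs("") returns None; a malformed
- # item such as "image" or "image=abc" raises ValueError.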
- @dataclass
- class EngineArgs:
- """Arguments for Aphrodite engine."""
- # Model Options
- model: str
- seed: int = 0
- served_model_name: Optional[Union[str, List[str]]] = None
- tokenizer: Optional[str] = None
- revision: Optional[str] = None
- code_revision: Optional[str] = None
- tokenizer_revision: Optional[str] = None
- tokenizer_mode: str = "auto"
- trust_remote_code: bool = False
- download_dir: Optional[str] = None
- max_model_len: Optional[int] = None
- max_context_len_to_capture: Optional[int] = None
- max_seq_len_to_capture: Optional[int] = None
- rope_scaling: Optional[dict] = None
- rope_theta: Optional[float] = None
- model_loader_extra_config: Optional[dict] = None
- enforce_eager: Optional[bool] = None
- skip_tokenizer_init: bool = False
- tokenizer_pool_size: int = 0
- # Note: Specifying a tokenizer pool by passing a class
- # is intended for expert use only. The API may change without
- # notice.
- tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
- tokenizer_pool_extra_config: Optional[dict] = None
- limit_mm_per_prompt: Optional[Mapping[str, int]] = None
- max_logprobs: int = 10 # OpenAI default is 5, setting to 10 because ST
- # Device Options
- device: str = "auto"
- # Load Options
- load_format: str = "auto"
- config_format: str = "auto"
- dtype: str = "auto"
- ignore_patterns: Optional[Union[str, List[str]]] = None
- # Parallel Options
- worker_use_ray: Optional[bool] = False
- tensor_parallel_size: int = 1
- pipeline_parallel_size: int = 1
- ray_workers_use_nsight: bool = False
- disable_custom_all_reduce: bool = False
- # Note: Specifying a custom executor backend by passing a class
- # is intended for expert use only. The API may change without
- # notice.
- distributed_executor_backend: Optional[Union[str,
- Type[ExecutorBase]]] = None
- max_parallel_loading_workers: Optional[int] = None
- # Quantization Options
- quantization: Optional[str] = None
- quantization_param_path: Optional[str] = None
- preemption_mode: Optional[str] = None
- deepspeed_fp_bits: Optional[int] = None
- quant_llm_fp_bits: Optional[int] = None
- quant_llm_exp_bits: Optional[int] = None
- # Cache Options
- kv_cache_dtype: str = "auto"
- block_size: int = 16
- enable_prefix_caching: Optional[bool] = False
- num_gpu_blocks_override: Optional[int] = None
- disable_sliding_window: bool = False
- gpu_memory_utilization: float = 0.90
- swap_space: float = 4 # GiB
- cpu_offload_gb: float = 0 # GiB
- # Scheduler Options
- use_v2_block_manager: bool = False
- scheduler_delay_factor: float = 0.0
- enable_chunked_prefill: Optional[bool] = None
- guided_decoding_backend: str = 'lm-format-enforcer'
- max_num_batched_tokens: Optional[int] = None
- max_num_seqs: int = 256
- num_scheduler_steps: int = 1
- single_user_mode: bool = False
- # Speculative Decoding Options
- num_lookahead_slots: int = 0
- speculative_model: Optional[str] = None
- speculative_model_quantization: Optional[str] = None
- num_speculative_tokens: Optional[int] = None
- speculative_max_model_len: Optional[int] = None
- ngram_prompt_lookup_max: Optional[int] = None
- ngram_prompt_lookup_min: Optional[int] = None
- speculative_draft_tensor_parallel_size: Optional[int] = None
- speculative_disable_by_batch_size: Optional[int] = None
- spec_decoding_acceptance_method: str = 'rejection_sampler'
- typical_acceptance_sampler_posterior_threshold: Optional[float] = None
- typical_acceptance_sampler_posterior_alpha: Optional[float] = None
- disable_logprobs_during_spec_decoding: Optional[bool] = None
- # Adapter Options
- enable_lora: bool = False
- max_loras: int = 1
- max_lora_rank: int = 16
- lora_extra_vocab_size: int = 256
- lora_dtype: str = "auto"
- max_cpu_loras: Optional[int] = None
- long_lora_scaling_factors: Optional[Tuple[float]] = None
- fully_sharded_loras: bool = False
- qlora_adapter_name_or_path: Optional[str] = None
- enable_prompt_adapter: bool = False
- max_prompt_adapters: int = 1
- max_prompt_adapter_token: int = 0
- # Log Options
- disable_log_stats: bool = False
- disable_async_output_proc: bool = False
- override_neuron_config: Optional[Dict[str, Any]] = None
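- # Illustrative sketch: EngineArgs can also be constructed directly in
- # Python, e.g. EngineArgs(model="EleutherAI/pythia-70m-deduped",
- # tensor_parallel_size=1, gpu_memory_utilization=0.90). Fields that are
- # not passed keep the defaults above, and __post_init__ below falls back
- # to using the model path as the tokenizer path when tokenizer is None.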
- def __post_init__(self):
- if self.tokenizer is None:
- self.tokenizer = self.model
- if is_cpu():
- self.distributed_executor_backend = None
- @staticmethod
- def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
- """Shared CLI arguments for the Aphrodite engine."""
- # Model Options
- parser.add_argument(
- "--model",
- type=str,
- default="EleutherAI/pythia-70m-deduped",
- help="Category: Model Options\n"
- "name or path of the huggingface model to use",
- )
- parser.add_argument("--seed",
- type=int,
- default=EngineArgs.seed,
- help="Category: Model Options\n"
- "random seed")
- parser.add_argument(
- "--served-model-name",
- nargs="+",
- type=str,
- default=None,
- help="Category: API Options\n"
- "The model name(s) used in the API. If multiple "
- "names are provided, the server will respond to any "
- "of the provided names. The model name in the model "
- "field of a response will be the first name in this "
- "list. If not specified, the model name will be the "
- "same as the `--model` argument. Noted that this name(s)"
- "will also be used in `model_name` tag content of "
- "prometheus metrics, if multiple names provided, metrics"
- "tag will take the first one.")
- parser.add_argument(
- "--tokenizer",
- type=str,
- default=EngineArgs.tokenizer,
- help="Category: Model Options\n"
- "name or path of the huggingface tokenizer to use",
- )
- parser.add_argument(
- "--revision",
- type=str,
- default=None,
- help="Category: Model Options\n"
- "the specific model version to use. It can be a branch "
- "name, a tag name, or a commit id. If unspecified, will use "
- "the default version.",
- )
- parser.add_argument(
- "--code-revision",
- type=str,
- default=None,
- help="Category: Model Options\n"
- "the specific revision to use for the model code on "
- "Hugging Face Hub. It can be a branch name, a tag name, or a "
- "commit id. If unspecified, will use the default version.",
- )
- parser.add_argument(
- "--tokenizer-revision",
- type=str,
- default=None,
- help="Category: Model Options\n"
- "the specific tokenizer version to use. It can be a branch "
- "name, a tag name, or a commit id. If unspecified, will use "
- "the default version.",
- )
- parser.add_argument(
- "--tokenizer-mode",
- type=str,
- default=EngineArgs.tokenizer_mode,
- choices=['auto', 'slow', 'mistral'],
- help='The tokenizer mode.\n\n* "auto" will use the '
- 'fast tokenizer if available.\n* "slow" will '
- 'always use the slow tokenizer. \n* '
- '"mistral" will always use the `mistral_common` tokenizer.')
- parser.add_argument(
- "--trust-remote-code",
- action="store_true",
- help="Category: Model Options\n"
- "trust remote code from huggingface",
- )
- parser.add_argument(
- "--download-dir",
- type=str,
- default=EngineArgs.download_dir,
- help="Category: Model Options\n"
- "directory to download and load the weights, "
- "default to the default cache dir of "
- "huggingface",
- )
- parser.add_argument(
- "--max-model-len",
- type=int,
- default=EngineArgs.max_model_len,
- help="Category: Model Options\n"
- "model context length. If unspecified, "
- "will be automatically derived from the model.",
- )
- parser.add_argument("--max-context-len-to-capture",
- type=int,
- default=EngineArgs.max_context_len_to_capture,
- help="Category: Model Options\n"
- "Maximum context length covered by CUDA "
- "graphs. When a sequence has context length "
- "larger than this, we fall back to eager mode. "
- "(DEPRECATED. Use --max-seq_len-to-capture instead"
- ")")
- parser.add_argument("--max-seq-len-to-capture",
- type=int,
- default=EngineArgs.max_seq_len_to_capture,
- help='Maximum sequence length covered by CUDA '
- 'graphs. When a sequence has context length '
- 'larger than this, we fall back to eager mode. '
- 'Additionally for encoder-decoder models, if the '
- 'sequence length of the encoder input is larger '
- 'than this, we fall back to the eager mode.')
- parser.add_argument('--rope-scaling',
- default=None,
- type=json.loads,
- help='Category: Model Options\n'
- 'RoPE scaling configuration in JSON format. '
- 'For example, {"type":"dynamic","factor":2.0}')
- parser.add_argument('--rope-theta',
- default=None,
- type=float,
- help='Category: Model Options\n'
- 'RoPE theta. Use with `rope_scaling`. In '
- 'some cases, changing the RoPE theta improves the '
- 'performance of the scaled model.')
- parser.add_argument("--model-loader-extra-config",
- type=str,
- default=EngineArgs.model_loader_extra_config,
- help="Category: Model Options\n"
- "Extra config for model loader. "
- "This will be passed to the model loader "
- "corresponding to the chosen load_format. "
- "This should be a JSON string that will be "
- "parsed into a dictionary.")
- parser.add_argument(
- "--enforce-eager",
- action=StoreBoolean,
- default=EngineArgs.enforce_eager,
- nargs="?",
- const="True",
- help="Category: Model Options\n"
- "Always use eager-mode PyTorch. If False, "
- "will use eager mode and CUDA graph in hybrid "
- "for maximal performance and flexibility.",
- )
- parser.add_argument("--skip-tokenizer-init",
- action="store_true",
- help="Category: Model Options\n"
- "Skip initialization of tokenizer and detokenizer")
- parser.add_argument("--tokenizer-pool-size",
- type=int,
- default=EngineArgs.tokenizer_pool_size,
- help="Category: Model Options\n"
- "Size of tokenizer pool to use for "
- "asynchronous tokenization. If 0, will "
- "use synchronous tokenization.")
- parser.add_argument("--tokenizer-pool-type",
- type=str,
- default=EngineArgs.tokenizer_pool_type,
- help="Category: Model Options\n"
- "The type of tokenizer pool to use for "
- "asynchronous tokenization. Ignored if "
- "tokenizer_pool_size is 0.")
- parser.add_argument("--tokenizer-pool-extra-config",
- type=str,
- default=EngineArgs.tokenizer_pool_extra_config,
- help="Category: Model Options\n"
- "Extra config for tokenizer pool. "
- "This should be a JSON string that will be "
- "parsed into a dictionary. Ignored if "
- "tokenizer_pool_size is 0.")
- # Multimodal related configs
- parser.add_argument(
- '--limit-mm-per-prompt',
- type=nullable_kvs,
- default=EngineArgs.limit_mm_per_prompt,
- # The default value is given in
- # MultiModalRegistry.init_mm_limits_per_prompt
- help=('For each multimodal plugin, limit how many '
- 'input instances to allow for each prompt. '
- 'Expects a comma-separated list of items, '
- 'e.g.: `image=16,video=2` allows a maximum of 16 '
- 'images and 2 videos per prompt. Defaults to 1 for '
- 'each modality.'))
- parser.add_argument(
- "--max-logprobs",
- type=int,
- default=EngineArgs.max_logprobs,
- help="Category: Model Options\n"
- "maximum number of log probabilities to "
- "return.",
- )
- # Device Options
- parser.add_argument(
- "--device",
- type=str,
- default=EngineArgs.device,
- choices=DEVICE_OPTIONS,
- help=("Category: Model Options\n"
- "Device to use for model execution."),
- )
- # Load Options
- parser.add_argument(
- '--load-format',
- type=str,
- default=EngineArgs.load_format,
- choices=[f.value for f in LoadFormat],
- help='Category: Model Options\n'
- 'The format of the model weights to load.\n\n'
- '* "auto" will try to load the weights in the safetensors format '
- 'and fall back to the pytorch bin format if safetensors format '
- 'is not available.\n'
- '* "pt" will load the weights in the pytorch bin format.\n'
- '* "safetensors" will load the weights in the safetensors format.\n'
- '* "npcache" will load the weights in pytorch format and store '
- 'a numpy cache to speed up the loading.\n'
- '* "dummy" will initialize the weights with random values, '
- 'which is mainly for profiling.\n'
- '* "tensorizer" will load the weights using tensorizer from '
- 'CoreWeave. See the Tensorize Aphrodite Model script in the '
- 'Examples section for more information.\n'
- '* "bitsandbytes" will load the weights using bitsandbytes '
- 'quantization.\n')
- parser.add_argument(
- '--config-format',
- default=EngineArgs.config_format,
- choices=[f.value for f in ConfigFormat],
- help='The format of the model config to load.\n\n'
- '* "auto" will try to load the config in hf format '
- 'if available else it will try to load in mistral format. '
- 'Mistral format is specific to mistral models and is not '
- 'compatible with other models.')
- parser.add_argument(
- '--dtype',
- type=str,
- default=EngineArgs.dtype,
- choices=[
- 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
- ],
- help='Category: Model Options\n'
- 'Data type for model weights and activations.\n\n'
- '* "auto" will use FP16 precision for FP32 and FP16 models, and '
- 'BF16 precision for BF16 models.\n'
- '* "half" for FP16. Recommended for AWQ quantization.\n'
- '* "float16" is the same as "half".\n'
- '* "bfloat16" for a balance between precision and range.\n'
- '* "float" is shorthand for FP32 precision.\n'
- '* "float32" for FP32 precision.')
- parser.add_argument(
- '--ignore-patterns',
- action="append",
- type=str,
- default=[],
- help="Category: Model Options\n"
- "The pattern(s) to ignore when loading the model."
- "Defaults to 'original/**/*' to avoid repeated loading of llama's "
- "checkpoints.")
- # Parallel Options
- parser.add_argument(
- '--worker-use-ray',
- action='store_true',
- help='Category: Parallel Options\n'
- 'Deprecated, use --distributed-executor-backend=ray.')
- parser.add_argument(
- "--tensor-parallel-size",
- "-tp",
- type=int,
- default=EngineArgs.tensor_parallel_size,
- help="Category: Parallel Options\n"
- "number of tensor parallel replicas, i.e. the number of GPUs "
- "to use.")
- parser.add_argument(
- "--pipeline-parallel-size",
- "-pp",
- type=int,
- default=EngineArgs.pipeline_parallel_size,
- help="Category: Parallel Options\n"
- "number of pipeline stages. Currently not supported.")
- parser.add_argument(
- "--ray-workers-use-nsight",
- action="store_true",
- help="Category: Parallel Options\n"
- "If specified, use nsight to profile ray workers",
- )
- parser.add_argument(
- "--disable-custom-all-reduce",
- action="store_true",
- default=EngineArgs.disable_custom_all_reduce,
- help="Category: Model Options\n"
- "See ParallelConfig",
- )
- parser.add_argument(
- '--distributed-executor-backend',
- choices=['ray', 'mp'],
- default=EngineArgs.distributed_executor_backend,
- help='Category: Parallel Options\n'
- 'Backend to use for distributed serving. When more than 1 GPU '
- 'is used, will be automatically set to "ray" if installed '
- 'or "mp" (multiprocessing) otherwise.')
- parser.add_argument(
- "--max-parallel-loading-workers",
- type=int,
- default=EngineArgs.max_parallel_loading_workers,
- help="Category: Parallel Options\n"
- "load model sequentially in multiple batches, "
- "to avoid RAM OOM when using tensor "
- "parallel and large models",
- )
- # Quantization Options
- parser.add_argument(
- "--quantization",
- "-q",
- type=str,
- choices=[*QUANTIZATION_METHODS, None],
- default=EngineArgs.quantization,
- help="Category: Quantization Options\n"
- "Method used to quantize the weights. If "
- "None, we first check the `quantization_config` "
- "attribute in the model config file. If that is "
- "None, we assume the model weights are not "
- "quantized and use `dtype` to determine the data "
- "type of the weights.",
- )
- parser.add_argument(
- '--quantization-param-path',
- type=str,
- default=None,
- help='Category: Quantization Options\n'
- 'Path to the JSON file containing the KV cache '
- 'scaling factors. This should generally be supplied, when '
- 'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
- 'default to 1.0, which may cause accuracy issues. '
- 'FP8_E5M2 (without scaling) is only supported on CUDA versions '
- 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
- 'supported for common inference criteria. ')
- parser.add_argument(
- '--preemption-mode',
- type=str,
- default=None,
- help='Category: Scheduler Options\n'
- 'If \'recompute\', the engine performs preemption by '
- 'recomputation; if \'swap\', the engine performs preemption by '
- 'block swapping.')
- parser.add_argument("--deepspeed-fp-bits",
- type=int,
- default=None,
- help="Category: Quantization Options\n"
- "Number of floating bits to use for the deepspeed "
- "quantization. Supported bits are: 4, 6, 8, 12.")
- parser.add_argument("--quant-llm-fp-bits",
- type=int,
- default=None,
- help="Category: Quantization Options\n"
- "Number of floating bits to use for the quant_llm "
- "quantization. Supported bits are: 4 to 15.")
- parser.add_argument("--quant-llm-exp-bits",
- type=int,
- default=None,
- help="Category: Quantization Options\n"
- "Number of exponent bits to use for the quant_llm "
- "quantization. Supported bits are: 1 to 5.")
- # Cache Options
- parser.add_argument(
- '--kv-cache-dtype',
- type=str,
- choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
- default=EngineArgs.kv_cache_dtype,
- help='Category: Cache Options\n'
- 'Data type for kv cache storage. If "auto", will use model '
- 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
- 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
- parser.add_argument(
- "--block-size",
- type=int,
- default=EngineArgs.block_size,
- choices=[8, 16, 32],
- help="Category: Cache Options\n"
- "token block size for contiguous chunks of "
- "tokens. This is ignored on neuron devices and "
- "set to max-model-len."
- )
- parser.add_argument(
- "--enable-prefix-caching",
- "--context-shift",
- action="store_true",
- help="Category: Cache Options\n"
- "Enable automatic prefix caching.",
- )
- parser.add_argument(
- "--num-gpu-blocks-override",
- type=int,
- default=None,
- help="Category: Cache Options Options\n"
- "If specified, ignore GPU profiling result and use this "
- "number of GPU blocks. Used for testing preemption.")
- parser.add_argument('--disable-sliding-window',
- action='store_true',
- help='Category: KV Cache Options\n'
- 'Disables sliding window, '
- 'capping to sliding window size')
- parser.add_argument(
- "--gpu-memory-utilization",
- "-gmu",
- type=float,
- default=EngineArgs.gpu_memory_utilization,
- help="Category: Cache Options\n"
- "The fraction of GPU memory to be used for "
- "the model executor, which can range from 0 to 1."
- "If unspecified, will use the default value of 0.9.",
- )
- parser.add_argument(
- "--swap-space",
- type=float,
- default=EngineArgs.swap_space,
- help="Category: Cache Options\n"
- "CPU swap space size (GiB) per GPU",
- )
- parser.add_argument(
- '--cpu-offload-gb',
- type=float,
- default=0,
- help='Category: Cache Options\n'
- 'The space in GiB to offload to CPU, per GPU. '
- 'Default is 0, which means no offloading. Intuitively, '
- 'this argument can be seen as a virtual way to increase '
- 'the GPU memory size. For example, if you have one 24 GB '
- 'GPU and set this to 10, virtually you can think of it as '
- 'a 34 GB GPU. Then you can load a 13B model with BF16 weights, '
- 'which requires at least 26 GB of GPU memory. Note that this '
- 'requires a fast CPU-GPU interconnect, as part of the model is '
- 'loaded from CPU memory to GPU memory on the fly in each '
- 'model forward pass.')
- # Scheduler Options
- parser.add_argument("--use-v2-block-manager",
- action="store_true",
- help="Category: Scheduler Options\n"
- "Use the v2 block manager.")
- parser.add_argument(
- "--scheduler-delay-factor",
- "-sdf",
- type=float,
- default=EngineArgs.scheduler_delay_factor,
- help="Category: Scheduler Options\n"
- "Apply a delay (of delay factor multiplied by previous "
- "prompt latency) before scheduling next prompt.")
- parser.add_argument(
- "--enable-chunked-prefill",
- action=StoreBoolean,
- default=EngineArgs.enable_chunked_prefill,
- nargs="?",
- const="True",
- help="Category: Scheduler Options\n"
- "If True, the prefill requests can be chunked based on the "
- "max_num_batched_tokens.")
- parser.add_argument(
- '--guided-decoding-backend',
- type=str,
- default='lm-format-enforcer',
- choices=['outlines', 'lm-format-enforcer'],
- help='Category: Scheduler Options\n'
- 'Which engine will be used for guided decoding'
- ' (JSON schema / regex etc) by default. Currently support '
- 'https://github.com/outlines-dev/outlines and '
- 'https://github.com/noamgat/lm-format-enforcer.'
- ' Can be overridden per request via guided_decoding_backend'
- ' parameter.')
- parser.add_argument(
- "--max-num-batched-tokens",
- type=int,
- default=EngineArgs.max_num_batched_tokens,
- help="Category: KV Cache Options\n"
- "maximum number of batched tokens per "
- "iteration",
- )
- parser.add_argument(
- "--max-num-seqs",
- type=int,
- default=EngineArgs.max_num_seqs,
- help="Category: API Options\n"
- "maximum number of sequences per iteration",
- )
- parser.add_argument('--single-user-mode',
- action='store_true',
- help='Category: API Options\n'
- 'If True, we only allocate blocks for one sequence '
- 'and use the maximum sequence length as the number '
- 'of tokens.')
- parser.add_argument('--num-scheduler-steps',
- type=int,
- default=1,
- help=('Maximum number of forward steps per '
- 'scheduler call.'))
- # Speculative Decoding Options
- parser.add_argument("--num-lookahead-slots",
- type=int,
- default=EngineArgs.num_lookahead_slots,
- help="Category: Speculative Decoding Options\n"
- "Experimental scheduling config necessary for "
- "speculative decoding. This will be replaced by "
- "speculative decoding config in the future; it is "
- "present for testing purposes until then.")
- parser.add_argument(
- "--speculative-model",
- type=str,
- default=EngineArgs.speculative_model,
- help="Category: Speculative Decoding Options\n"
- "The name of the draft model to be used in speculative decoding.")
- # Quantization settings for speculative model.
- parser.add_argument(
- '--speculative-model-quantization',
- type=str,
- choices=[*QUANTIZATION_METHODS, None],
- default=EngineArgs.speculative_model_quantization,
- help='Method used to quantize the weights of the speculative model. '
- 'If None, we first check the `quantization_config` '
- 'attribute in the model config file. If that is '
- 'None, we assume the model weights are not '
- 'quantized and use `dtype` to determine the data '
- 'type of the weights.')
- parser.add_argument("--num-speculative-tokens",
- type=int,
- default=EngineArgs.num_speculative_tokens,
- help="Category: Speculative Decoding Options\n"
- "The number of speculative tokens to sample from "
- "the draft model in speculative decoding")
- parser.add_argument(
- "--speculative-max-model-len",
- type=int,
- default=EngineArgs.speculative_max_model_len,
- help="Category: Speculative Decoding Options\n"
- "The maximum sequence length supported by the "
- "draft model. Sequences over this length will skip "
- "speculation.")
- parser.add_argument(
- "--ngram-prompt-lookup-max",
- type=int,
- default=EngineArgs.ngram_prompt_lookup_max,
- help="Category: Speculative Decoding Options\n"
- "Max size of window for ngram prompt lookup in speculative "
- "decoding.")
- parser.add_argument(
- "--ngram-prompt-lookup-min",
- type=int,
- default=EngineArgs.ngram_prompt_lookup_min,
- help="Category: Speculative Decoding Options\n"
- "Min size of window for ngram prompt lookup in speculative "
- "decoding.")
- parser.add_argument(
- "--speculative-draft-tensor-parallel-size",
- "-spec-draft-tp",
- type=int,
- default=EngineArgs.speculative_draft_tensor_parallel_size,
- help="Category: Speculative Decoding Options\n"
- "Number of tensor parallel replicas for "
- "the draft model in speculative decoding.")
- parser.add_argument(
- "--speculative-disable-by-batch-size",
- type=int,
- default=EngineArgs.speculative_disable_by_batch_size,
- help="Category: Speculative Decoding Options\n"
- "Disable speculative decoding for new incoming requests "
- "if the number of enqueue requests is larger than this value.")
- parser.add_argument(
- '--spec-decoding-acceptance-method',
- type=str,
- default=EngineArgs.spec_decoding_acceptance_method,
- choices=['rejection_sampler', 'typical_acceptance_sampler'],
- help='Category: Speculative Decoding Options\n'
- 'Specify the acceptance method to use during draft token '
- 'verification in speculative decoding. Two types of acceptance '
- 'routines are supported: '
- '1) RejectionSampler which does not allow changing the '
- 'acceptance rate of draft tokens, '
- '2) TypicalAcceptanceSampler which is configurable, allowing for '
- 'a higher acceptance rate at the cost of lower quality, '
- 'and vice versa.')
- parser.add_argument(
- '--typical-acceptance-sampler-posterior-threshold',
- type=float,
- default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
- help='Category: Speculative Decoding Options\n'
- 'Set the lower bound threshold for the posterior '
- 'probability of a token to be accepted. This threshold is '
- 'used by the TypicalAcceptanceSampler to make sampling decisions '
- 'during speculative decoding. Defaults to 0.09')
- parser.add_argument(
- '--typical-acceptance-sampler-posterior-alpha',
- type=float,
- default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
- help='Category: Speculative Decoding Options\n'
- 'A scaling factor for the entropy-based threshold for token '
- 'acceptance in the TypicalAcceptanceSampler. Typically defaults '
- 'to sqrt of --typical-acceptance-sampler-posterior-threshold '
- 'i.e. 0.3')
- parser.add_argument(
- '--disable-logprobs-during-spec-decoding',
- action=StoreBoolean,
- nargs="?",
- const="True",
- default=EngineArgs.disable_logprobs_during_spec_decoding,
- help='Category: Speculative Decoding Options\n'
- 'If set to True, token log probabilities are not returned '
- 'during speculative decoding. If set to False, log probabilities '
- 'are returned according to the settings in SamplingParams. If '
- 'not specified, it defaults to True. Disabling log probabilities '
- 'during speculative decoding reduces latency by skipping logprob '
- 'calculation in proposal sampling, target sampling, and after '
- 'accepted tokens are determined.')
- # Adapter Options
- parser.add_argument(
- "--enable-lora",
- action="store_true",
- help="Category: Adapter Options\n"
- "If True, enable handling of LoRA adapters.",
- )
- parser.add_argument(
- "--max-loras",
- type=int,
- default=EngineArgs.max_loras,
- help="Category: Adapter Options\n"
- "Max number of LoRAs in a single batch.",
- )
- parser.add_argument(
- "--max-lora-rank",
- type=int,
- default=EngineArgs.max_lora_rank,
- help="Category: Adapter Options\n"
- "Max LoRA rank.",
- )
- parser.add_argument(
- "--lora-extra-vocab-size",
- type=int,
- default=EngineArgs.lora_extra_vocab_size,
- help=("Category: Adapter Options\n"
- "Maximum size of extra vocabulary that can be "
- "present in a LoRA adapter (added to the base "
- "model vocabulary)."),
- )
- parser.add_argument(
- "--lora-dtype",
- type=str,
- default=EngineArgs.lora_dtype,
- choices=["auto", "float16", "bfloat16", "float32"],
- help=("Category: Adapter Options\n"
- "Data type for LoRA. If auto, will default to "
- "base model dtype."),
- )
- parser.add_argument(
- "--max-cpu-loras",
- type=int,
- default=EngineArgs.max_cpu_loras,
- help=("Category: Adapter Options\n"
- "Maximum number of LoRAs to store in CPU memory. "
- "Must be >= than max_num_seqs. "
- "Defaults to max_num_seqs."),
- )
- parser.add_argument(
- "--long-lora-scaling-factors",
- type=str,
- default=EngineArgs.long_lora_scaling_factors,
- help=("Category: Adapter Options\n"
- "Specify multiple scaling factors (which can "
- "be different from base model scaling factor "
- "- see eg. Long LoRA) to allow for multiple "
- "LoRA adapters trained with those scaling "
- "factors to be used at the same time. If not "
- "specified, only adapters trained with the "
- "base model scaling factor are allowed."))
- parser.add_argument(
- "--fully-sharded-loras",
- action='store_true',
- help=("Category: Adapter Options\n"
- "By default, only half of the LoRA computation is sharded "
- "with tensor parallelism. Enabling this will use the fully "
- "sharded layers. At high sequence length, max rank or "
- "tensor parallel size, this is likely faster."))
- parser.add_argument("--qlora-adapter-name-or-path",
- type=str,
- default=None,
- help="Category: Adapter Options\n"
- "Name or path of the LoRA adapter to use.")
- parser.add_argument('--enable-prompt-adapter',
- action='store_true',
- help='Category: Adapter Options\n'
- 'If True, enable handling of PromptAdapters.')
- parser.add_argument('--max-prompt-adapters',
- type=int,
- default=EngineArgs.max_prompt_adapters,
- help='Category: Adapter Options\n'
- 'Max number of PromptAdapters in a batch.')
- parser.add_argument('--max-prompt-adapter-token',
- type=int,
- default=EngineArgs.max_prompt_adapter_token,
- help='Category: Adapter Options\n'
- 'Max number of PromptAdapter tokens.')
- # Log Options
- parser.add_argument(
- "--disable-log-stats",
- action="store_true",
- help="Category: Log Options\n"
- "disable logging statistics",
- )
- parser.add_argument(
- "--disable-async-output-proc",
- action="store_true",
- default=EngineArgs.disable_async_output_proc,
- help="Disable async output processing. THis may result in "
- "lower performance.")
- parser.add_argument(
- '--override-neuron-config',
- type=lambda configs: {
- str(key): value
- for key, value in
- (config.split(':') for config in configs.split(','))
- },
- default=None,
- help="override or set neuron device configuration.")
- return parser
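- # Illustrative sketch of wiring the shared flags into a parser:
- #     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
- #     args = parser.parse_args(["--model", "EleutherAI/pythia-70m-deduped",
- #                               "--dtype", "bfloat16"])
- #     engine_args = EngineArgs.from_cli_args(args)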
- @classmethod
- def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
- # Get the list of attributes of this dataclass.
- attrs = [attr.name for attr in dataclasses.fields(cls)]
- # Set the attributes from the parsed arguments.
- engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
- return engine_args
- def create_engine_config(self) -> EngineConfig:
- # gguf file needs a specific model loader and doesn't use hf_repo
- if check_gguf_file(self.model):
- self.quantization = self.load_format = "gguf"
- # bitsandbytes quantization needs a specific model loader
- # so we make sure the quant method and the load format are consistent
- if (self.quantization == "bitsandbytes" or
- self.qlora_adapter_name_or_path is not None) and \
- self.load_format != "bitsandbytes":
- raise ValueError(
- "BitsAndBytes quantization and QLoRA adapter only support "
- f"'bitsandbytes' load format, but got {self.load_format}")
- if (self.load_format == "bitsandbytes" or
- self.qlora_adapter_name_or_path is not None) and \
- self.quantization != "bitsandbytes":
- raise ValueError(
- "BitsAndBytes load format and QLoRA adapter only support "
- f"'bitsandbytes' quantization, but got {self.quantization}")
- assert self.cpu_offload_gb >= 0, (
- "CPU offload space must be non-negative"
- f", but got {self.cpu_offload_gb}")
- device_config = DeviceConfig(device=self.device)
- model_config = ModelConfig(
- model=self.model,
- tokenizer=self.tokenizer,
- tokenizer_mode=self.tokenizer_mode,
- trust_remote_code=self.trust_remote_code,
- dtype=self.dtype,
- seed=self.seed,
- revision=self.revision,
- code_revision=self.code_revision,
- rope_scaling=self.rope_scaling,
- rope_theta=self.rope_theta,
- tokenizer_revision=self.tokenizer_revision,
- max_model_len=self.max_model_len,
- quantization=self.quantization,
- deepspeed_fp_bits=self.deepspeed_fp_bits,
- quant_llm_fp_bits=self.quant_llm_fp_bits,
- quant_llm_exp_bits=self.quant_llm_exp_bits,
- quantization_param_path=self.quantization_param_path,
- enforce_eager=self.enforce_eager,
- max_context_len_to_capture=self.max_context_len_to_capture,
- max_seq_len_to_capture=self.max_seq_len_to_capture,
- max_logprobs=self.max_logprobs,
- disable_sliding_window=self.disable_sliding_window,
- skip_tokenizer_init=self.skip_tokenizer_init,
- served_model_name=self.served_model_name,
- limit_mm_per_prompt=self.limit_mm_per_prompt,
- use_async_output_proc=not self.disable_async_output_proc,
- config_format=self.config_format,
- override_neuron_config=self.override_neuron_config
- )
- if model_config.is_multimodal_model:
- if self.enable_prefix_caching:
- logger.warning(
- "--enable-prefix-caching is currently not "
- "supported for multimodal models and has been disabled.")
- self.enable_prefix_caching = False
- cache_config = CacheConfig(
- block_size=self.block_size if self.device != "neuron" else
- self.max_model_len,
- gpu_memory_utilization=self.gpu_memory_utilization,
- swap_space=self.swap_space,
- cache_dtype=self.kv_cache_dtype,
- is_attention_free=model_config.is_attention_free(),
- num_gpu_blocks_override=self.num_gpu_blocks_override,
- sliding_window=model_config.get_sliding_window(),
- enable_prefix_caching=self.enable_prefix_caching,
- cpu_offload_gb=self.cpu_offload_gb,
- )
- parallel_config = ParallelConfig(
- pipeline_parallel_size=self.pipeline_parallel_size,
- tensor_parallel_size=self.tensor_parallel_size,
- worker_use_ray=self.worker_use_ray,
- max_parallel_loading_workers=self.max_parallel_loading_workers,
- disable_custom_all_reduce=self.disable_custom_all_reduce,
- tokenizer_pool_config=TokenizerPoolConfig.create_config(
- tokenizer_pool_size=self.tokenizer_pool_size,
- tokenizer_pool_type=self.tokenizer_pool_type,
- tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
- ),
- ray_workers_use_nsight=self.ray_workers_use_nsight,
- distributed_executor_backend=self.distributed_executor_backend)
- max_model_len = model_config.max_model_len
- use_long_context = max_model_len > 32768
- if self.enable_chunked_prefill is None:
- # If not explicitly set, enable chunked prefill by default for
- # long context (> 32K) models. This is to avoid OOM errors in the
- # initial memory profiling phase.
- # Chunked prefill is currently disabled for multimodal models by
- # default.
- if use_long_context and not model_config.is_multimodal_model:
- is_gpu = device_config.device_type == "cuda"
- use_sliding_window = (model_config.get_sliding_window()
- is not None)
- use_spec_decode = self.speculative_model is not None
- has_seqlen_agnostic_layers = (
- model_config.contains_seqlen_agnostic_layers(
- parallel_config))
- if (is_gpu and not use_sliding_window and not use_spec_decode
- and not self.enable_lora
- and not self.enable_prompt_adapter
- and not has_seqlen_agnostic_layers):
- self.enable_chunked_prefill = True
- logger.warning(
- "Chunked prefill is enabled by default for models with "
- "max_model_len > 32K. Currently, chunked prefill might "
- "not work with some features or models. If you "
- "encounter any issues, please disable chunked prefill "
- "by setting --enable-chunked-prefill=False.")
- if self.enable_chunked_prefill is None:
- self.enable_chunked_prefill = False
- if not self.enable_chunked_prefill and use_long_context:
- logger.warning(
- f"The model has a long context length ({max_model_len}). "
- "This may cause OOM errors during the initial memory "
- "profiling phase, or result in low performance due to small "
- "KV cache space. Consider setting --max-model-len to a "
- "smaller value.")
-
- if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
- self.use_v2_block_manager = True
- logger.warning(
- "Enabled BlockSpaceManagerV2 because it is "
- "required for multi-step scheduling.")
- speculative_config = SpeculativeConfig.maybe_create_spec_config(
- target_model_config=model_config,
- target_parallel_config=parallel_config,
- target_dtype=self.dtype,
- speculative_model=self.speculative_model,
- speculative_model_quantization = \
- self.speculative_model_quantization,
- speculative_draft_tensor_parallel_size=self.
- speculative_draft_tensor_parallel_size,
- num_speculative_tokens=self.num_speculative_tokens,
- speculative_disable_by_batch_size=self.
- speculative_disable_by_batch_size,
- speculative_max_model_len=self.speculative_max_model_len,
- enable_chunked_prefill=self.enable_chunked_prefill,
- use_v2_block_manager=self.use_v2_block_manager,
- disable_log_stats=self.disable_log_stats,
- ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
- ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
- draft_token_acceptance_method=\
- self.spec_decoding_acceptance_method,
- typical_acceptance_sampler_posterior_threshold=self.
- typical_acceptance_sampler_posterior_threshold,
- typical_acceptance_sampler_posterior_alpha=self.
- typical_acceptance_sampler_posterior_alpha,
- disable_logprobs=self.disable_logprobs_during_spec_decoding,
- )
- if self.num_scheduler_steps > 1:
- if speculative_config is not None:
- raise ValueError("Speculative decoding is not supported with "
- "multi-step (--num-scheduler-steps > 1)")
- if self.enable_chunked_prefill:
- raise ValueError("Chunked prefill is not supported with "
- "multi-step (--num-scheduler-steps > 1)")
- # Make sure num_lookahead_slots is set to the higher value, depending
- # on whether we are using speculative decoding or multi-step
- # scheduling.
- num_lookahead_slots = max(self.num_lookahead_slots,
- self.num_scheduler_steps - 1)
- num_lookahead_slots = num_lookahead_slots \
- if speculative_config is None \
- else speculative_config.num_lookahead_slots
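- # Worked example (illustrative): with --num-scheduler-steps 8, no
- # speculative model and the default --num-lookahead-slots 0, this
- # resolves to max(0, 8 - 1) = 7 lookahead slots.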
- scheduler_config = SchedulerConfig(
- max_num_batched_tokens=self.max_num_batched_tokens,
- max_num_seqs=self.max_num_seqs,
- max_model_len=model_config.max_model_len,
- cache_config=cache_config,
- is_attention_free=model_config.is_attention_free(),
- use_v2_block_manager=self.use_v2_block_manager,
- num_lookahead_slots=num_lookahead_slots,
- delay_factor=self.scheduler_delay_factor,
- enable_chunked_prefill=self.enable_chunked_prefill,
- embedding_mode=model_config.embedding_mode,
- is_multimodal_model=model_config.is_multimodal_model,
- preemption_mode=self.preemption_mode,
- num_scheduler_steps=self.num_scheduler_steps,
- send_delta_data=(APHRODITE_USE_RAY_SPMD_WORKER and
- parallel_config.use_ray),
- single_user_mode=self.single_user_mode,
- )
- if not HAS_TRITON and self.enable_lora:
- raise ValueError("Triton is not installed, LoRA will not work.")
- lora_config = LoRAConfig(
- max_lora_rank=self.max_lora_rank,
- max_loras=self.max_loras,
- fully_sharded_loras=self.fully_sharded_loras,
- lora_extra_vocab_size=self.lora_extra_vocab_size,
- long_lora_scaling_factors=self.long_lora_scaling_factors,
- lora_dtype=self.lora_dtype,
- max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
- and self.max_cpu_loras > 0 else None) if self.enable_lora else None
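- # Illustrative: lora_config is only built when --enable-lora is passed
- # (e.g. together with --max-loras 4 --max-lora-rank 32) and Triton is
- # available; otherwise it stays None and LoRA handling is disabled.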
- if self.qlora_adapter_name_or_path is not None and \
- self.qlora_adapter_name_or_path != "":
- if self.model_loader_extra_config is None:
- self.model_loader_extra_config = {}
- self.model_loader_extra_config[
- "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
- load_config = LoadConfig(
- load_format=self.load_format,
- download_dir=self.download_dir,
- model_loader_extra_config=self.model_loader_extra_config,
- ignore_patterns=self.ignore_patterns)
- prompt_adapter_config = PromptAdapterConfig(
- max_prompt_adapters=self.max_prompt_adapters,
- max_prompt_adapter_token=self.max_prompt_adapter_token) \
- if self.enable_prompt_adapter else None
- decoding_config = DecodingConfig(
- guided_decoding_backend=self.guided_decoding_backend)
- if (model_config.get_sliding_window() is not None
- and scheduler_config.chunked_prefill_enabled
- and not scheduler_config.use_v2_block_manager):
- raise ValueError(
- "Chunked prefill is not supported with sliding window. "
- "Set --disable-sliding-window to disable sliding window.")
- return EngineConfig(model_config=model_config,
- cache_config=cache_config,
- parallel_config=parallel_config,
- scheduler_config=scheduler_config,
- device_config=device_config,
- lora_config=lora_config,
- speculative_config=speculative_config,
- load_config=load_config,
- decoding_config=decoding_config,
- prompt_adapter_config=prompt_adapter_config)
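- # Illustrative sketch of the end-to-end flow:
- #     engine_args = EngineArgs(model="EleutherAI/pythia-70m-deduped")
- #     engine_config = engine_args.create_engine_config()
- # The returned EngineConfig bundles the model, cache, parallel, scheduler,
- # device, LoRA, speculative, load, decoding and prompt-adapter configs.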
- @dataclass
- class AsyncEngineArgs(EngineArgs):
- """Arguments for asynchronous Aphrodite engine."""
- disable_log_requests: bool = False
- uvloop: bool = False
- @staticmethod
- def add_cli_args(parser: FlexibleArgumentParser,
- async_args_only: bool = False) -> FlexibleArgumentParser:
- if not async_args_only:
- parser = EngineArgs.add_cli_args(parser)
- parser.add_argument('--disable-log-requests',
- action='store_true',
- help='Disable logging requests.')
- parser.add_argument(
- "--uvloop",
- action="store_true",
- help="Use the Uvloop asyncio event loop to possibly increase "
- "performance")
- return parser
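- # Illustrative sketch: AsyncEngineArgs layers the async-only flags on top
- # of the shared engine flags, e.g.
- #     parser = AsyncEngineArgs.add_cli_args(FlexibleArgumentParser())
- #     args = parser.parse_args(["--model", "EleutherAI/pythia-70m-deduped",
- #                               "--uvloop", "--disable-log-requests"])
- #     engine_args = AsyncEngineArgs.from_cli_args(args)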
- class StoreBoolean(argparse.Action):
- def __call__(self, parser, namespace, values, option_string=None):
- if values.lower() == "true":
- setattr(namespace, self.dest, True)
- elif values.lower() == "false":
- setattr(namespace, self.dest, False)
- else:
- raise ValueError(f"Invalid boolean value: {values}. "
- "Expected 'true' or 'false'.")