- import time
- from collections import deque
- from contextlib import contextmanager
- from dataclasses import dataclass
- from functools import partial
- from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
- Iterable, List, NamedTuple, Optional)
- from typing import Sequence as GenericSequence
- from typing import Set, Type, Union
- import torch
- from loguru import logger
- from typing_extensions import TypeVar
- import aphrodite.common.envs as envs
- from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
- EngineConfig, LoadConfig, LoRAConfig,
- ModelConfig, ParallelConfig,
- PromptAdapterConfig, SchedulerConfig,
- SpeculativeConfig)
- from aphrodite.common.logger import setup_logger
- from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
- RequestOutputFactory)
- from aphrodite.common.pooling_params import PoolingParams
- from aphrodite.common.sampling_params import RequestOutputKind, SamplingParams
- from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
- ExecuteModelRequest, Sequence,
- SequenceGroup, SequenceGroupMetadata,
- SequenceStatus)
- from aphrodite.common.utils import Counter, Device, weak_bind
- from aphrodite.engine.args_tools import EngineArgs
- from aphrodite.engine.metrics_types import StatLoggerBase, Stats
- from aphrodite.engine.output_processor.interfaces import (
- SequenceGroupOutputProcessor)
- from aphrodite.engine.output_processor.stop_checker import StopChecker
- from aphrodite.engine.output_processor.util import (
- create_output_by_sequence_group)
- from aphrodite.executor.executor_base import ExecutorBase
- from aphrodite.executor.ray_utils import initialize_ray_cluster
- from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
- InputRegistry, LLMInputs, PromptType)
- from aphrodite.inputs.preprocess import InputPreprocessor
- from aphrodite.lora.request import LoRARequest
- from aphrodite.modeling.layers.sampler import SamplerOutput
- from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
- SchedulerOutputs)
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
- from aphrodite.transformers_utils.config import try_get_generation_config
- from aphrodite.transformers_utils.detokenizer import Detokenizer
- from aphrodite.transformers_utils.tokenizer import AnyTokenizer
- from aphrodite.transformers_utils.tokenizer_group import (
- BaseTokenizerGroup, init_tokenizer_from_configs)
- from aphrodite.version import __version__ as APHRODITE_VERSION
- _LOCAL_LOGGING_INTERVAL_SEC = 5
- APHRODITE_USE_RAY_SPMD_WORKER = envs.APHRODITE_USE_RAY_SPMD_WORKER
- def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
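- """Return the model's generation config as a diff dict, or {} if unavailable."""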
- config = try_get_generation_config(
- model_config.model,
- trust_remote_code=model_config.trust_remote_code,
- revision=model_config.revision,
- )
- if config is None:
- return {}
- return config.to_diff_dict()
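- # _G bounds the tokenizer-group type returned by get_tokenizer_group();
- # _O ranges over the request output types checked by validate_output(s).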
- _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
- _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
- @dataclass
- class SchedulerOutputState:
- """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
- scheduler_outputs: Optional[SchedulerOutputs] = None
- allow_async_output_proc: bool = False
- last_output: Optional[SamplerOutput] = None
- class OutputData(NamedTuple):
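- """Sampler outputs for one iteration plus the scheduling metadata needed
- to post-process them; ``skip`` marks seq-group indices already handled."""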
- outputs: List[SamplerOutput]
- seq_group_metadata_list: List[SequenceGroupMetadata]
- scheduler_outputs: SchedulerOutputs
- is_async: bool
- is_last_step: bool
- skip: List[int]
- class SchedulerContext:
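- """Per-virtual-engine bookkeeping for output processing: queued model
- outputs and the request outputs produced from them."""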
- def __init__(self):
- self.output_queue: Deque[OutputData] = deque()
- self.request_outputs: List[Union[RequestOutput,
- EmbeddingRequestOutput]] = []
- self.seq_group_metadata_list: Optional[
- List[SequenceGroupMetadata]] = None
- self.scheduler_outputs: Optional[SchedulerOutputs] = None
- def append_output(self, outputs: List[SamplerOutput],
- seq_group_metadata_list: List[SequenceGroupMetadata],
- scheduler_outputs: SchedulerOutputs, is_async: bool,
- is_last_step: bool):
- self.output_queue.append(
- OutputData(outputs=outputs,
- seq_group_metadata_list=seq_group_metadata_list,
- scheduler_outputs=scheduler_outputs,
- is_async=is_async,
- is_last_step=is_last_step,
- skip=[]))
- class AphroditeEngine:
- """An LLM engine that receives requests and generates texts.
- This is the main class for the Aphrodite engine. It receives requests
- from clients and generates texts from the LLM. It includes a tokenizer, a
- language model (possibly distributed across multiple GPUs), and GPU memory
- space allocated for intermediate states (aka KV cache). This class utilizes
- iteration-level scheduling and efficient memory management to maximize the
- serving throughput.
- The `LLM` class wraps this class for offline batched inference and the
- `AsyncAphrodite` class wraps this class for online serving.
- NOTE: The config arguments are derived from the `EngineArgs` class. For the
- comprehensive list of arguments, see `EngineArgs`.
- Args:
- model_config: The configuration related to the LLM model.
- cache_config: The configuration related to the KV cache memory
- management.
- parallel_config: The configuration related to distributed execution.
- scheduler_config: The configuration related to the request scheduler.
- device_config: The configuration related to the device.
- load_config: The configuration related to loading the model weights.
- lora_config (Optional): The configuration related to serving multi-LoRA.
- speculative_config (Optional): The configuration related to speculative
- decoding.
- decoding_config (Optional): The configuration related to decoding
- (e.g. the guided decoding backend).
- executor_class: The model executor class for managing distributed
- execution.
- prompt_adapter_config (Optional): The configuration related to serving
- prompt adapters.
- log_stats: Whether to log statistics.
- """
- DO_VALIDATE_OUTPUT: ClassVar[bool] = False
- """A flag to toggle whether to validate the type of request output."""
- @classmethod
- @contextmanager
- def enable_output_validation(cls):
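- """Context manager that enables request-output type validation while active."""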
- cls.DO_VALIDATE_OUTPUT = True
- yield
- cls.DO_VALIDATE_OUTPUT = False
- @classmethod
- def validate_output(
- cls,
- output: object,
- output_type: Type[_O],
- ) -> _O:
- do_validate = cls.DO_VALIDATE_OUTPUT
- if ((TYPE_CHECKING or do_validate)
- and not isinstance(output, output_type)):
- raise TypeError(f"Expected output of type {output_type}, "
- f"but found type {type(output)}")
- return output
- @classmethod
- def validate_outputs(
- cls,
- outputs: GenericSequence[object],
- output_type: Type[_O],
- ) -> List[_O]:
- do_validate = cls.DO_VALIDATE_OUTPUT
- outputs_: List[_O]
- if TYPE_CHECKING or do_validate:
- outputs_ = []
- for output in outputs:
- if not isinstance(output, output_type):
- raise TypeError(f"Expected output of type {output_type}, "
- f"but found type {type(output)}")
- outputs_.append(output)
- else:
- outputs_ = outputs
- return outputs_
- tokenizer: Optional[BaseTokenizerGroup]
- def __init__(
- self,
- model_config: ModelConfig,
- cache_config: CacheConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- load_config: LoadConfig,
- lora_config: Optional[LoRAConfig],
- speculative_config: Optional[SpeculativeConfig],
- decoding_config: Optional[DecodingConfig],
- prompt_adapter_config: Optional[PromptAdapterConfig],
- executor_class: Type[ExecutorBase],
- log_stats: bool,
- stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
- input_registry: InputRegistry = INPUT_REGISTRY,
- ) -> None:
- try:
- import aphrodite.commit_id
- commit_id = True
- except ImportError:
- commit_id = False
- config_dict = {
- "Model": model_config.model,
- "Speculative Config": speculative_config,
- "DataType": model_config.dtype,
- "Model Load Format": load_config.load_format,
- "Tensor Parallel Size": parallel_config.tensor_parallel_size,
- "Pipeline Parallel Size": parallel_config.pipeline_parallel_size,
- "Disable Custom All-Reduce":
- parallel_config.disable_custom_all_reduce,
- "Quantization Format": model_config.quantization,
- "Context Length": model_config.max_model_len,
- "Enforce Eager Mode": model_config.enforce_eager,
- "Prefix Caching": cache_config.enable_prefix_caching,
- "KV Cache DataType": cache_config.cache_dtype,
- "Device": device_config.device,
- "Rope Scaling": model_config.rope_scaling,
- "Guided Decoding Backend": decoding_config,
- "Scheduler Steps": scheduler_config.num_scheduler_steps,
- "Async Output Processing": model_config.use_async_output_proc,
- }
- logger.info("-" * 85)
- if not commit_id:
- logger.info(
- f"Initializing Aphrodite Engine (v{APHRODITE_VERSION}) "
- "with the following config:")
- else:
- logger.info(f"Initializing Aphrodite Engine (v{APHRODITE_VERSION} "
- f"commit {aphrodite.__short_commit__}) with the "
- "following config:")
- for key, value in config_dict.items():
- if value is not None and not (
- key in ("Model Load Format", "KV Cache DataType")
- and value == "auto"):
- logger.info(f"{key} = {value!r}")
- logger.info("-" * 85)
- # TODO: Print more configs in debug mode.
- from aphrodite.plugins import load_general_plugins
- load_general_plugins()
- self.model_config = model_config
- self.cache_config = cache_config
- self.lora_config = lora_config
- self.parallel_config = parallel_config
- self.scheduler_config = scheduler_config
- self.device_config = device_config
- self.speculative_config = speculative_config
- self.load_config = load_config
- self.decoding_config = decoding_config or DecodingConfig()
- self.prompt_adapter_config = prompt_adapter_config
- self.log_stats = log_stats
- if not self.model_config.skip_tokenizer_init:
- self.tokenizer = self._init_tokenizer()
- self.detokenizer = Detokenizer(self.tokenizer)
- tokenizer_group = self.get_tokenizer_group()
- else:
- self.tokenizer = None
- self.detokenizer = None
- tokenizer_group = None
- # Ensure that the function doesn't contain a reference to self,
- # to avoid engine GC issues
- def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
- assert tokenizer_group, ("tokenizer_group cannot be None, "
- "make sure skip_tokenizer_init is False")
- return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
- self.seq_counter = Counter()
- self.generation_config_fields = _load_generation_config_dict(
- model_config)
- self.input_preprocessor = InputPreprocessor(model_config,
- self.tokenizer)
- self.input_registry = input_registry
- self.input_processor = input_registry.create_input_processor(
- model_config)
- self.model_executor = executor_class(
- model_config=model_config,
- cache_config=cache_config,
- parallel_config=parallel_config,
- scheduler_config=scheduler_config,
- device_config=device_config,
- lora_config=lora_config,
- speculative_config=speculative_config,
- load_config=load_config,
- prompt_adapter_config=prompt_adapter_config,
- )
- if not self.model_config.embedding_mode:
- self._initialize_kv_caches()
- if self.tokenizer:
- # Ping the tokenizer to ensure liveness if it runs in a
- # different process.
- self.tokenizer.ping()
- self.cached_scheduler_outputs = [
- SchedulerOutputState()
- for _ in range(self.parallel_config.pipeline_parallel_size)
- ]
- self.scheduler_contexts = [
- SchedulerContext()
- for _ in range(self.parallel_config.pipeline_parallel_size)
- ]
- if model_config.use_async_output_proc:
- process_model_outputs = weak_bind(self._process_model_outputs)
- self.async_callbacks = [
- partial(process_model_outputs,
- ctx=self.scheduler_contexts[v_id])
- for v_id in range(self.parallel_config.pipeline_parallel_size)
- ]
- else:
- self.async_callbacks = []
- # Currently used by AsyncAphrodite to ensure quick append
- # of request outputs to asyncio queues
- self.process_request_outputs_callback: Optional[Callable] = None
- # Create the scheduler.
- # NOTE: the cache_config here has been updated with the numbers of
- # GPU and CPU blocks, which are profiled in the distributed executor.
- self.scheduler = [
- Scheduler(
- scheduler_config, cache_config, lora_config,
- parallel_config.pipeline_parallel_size,
- self.async_callbacks[v_id]
- if model_config.use_async_output_proc else None)
- for v_id in range(parallel_config.pipeline_parallel_size)
- ]
- # Metric Logging.
- if self.log_stats:
- if stat_loggers is not None:
- self.stat_loggers = stat_loggers
- else:
- # Lazy import for prometheus multiprocessing.
- # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
- # before prometheus_client is imported.
- # See https://prometheus.github.io/client_python/multiprocess/
- from aphrodite.engine.metrics import (LoggingStatLogger,
- PrometheusStatLogger)
- self.stat_loggers = {
- "logging":
- LoggingStatLogger(
- local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
- "prometheus":
- PrometheusStatLogger(
- local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
- labels=dict(model_name=model_config.served_model_name),
- max_model_len=self.model_config.max_model_len),
- }
- self.stat_loggers["prometheus"].info("cache_config",
- self.cache_config)
- # Create sequence output processor, e.g. for beam search or
- # speculative decoding.
- self.output_processor = (
- SequenceGroupOutputProcessor.create_output_processor(
- self.scheduler_config,
- self.detokenizer,
- self.scheduler,
- self.seq_counter,
- get_tokenizer_for_seq,
- stop_checker=StopChecker(
- self.scheduler_config.max_model_len,
- get_tokenizer_for_seq,
- ),
- ))
- def _initialize_kv_caches(self) -> None:
- """Initialize the KV cache in the worker(s).
- The workers will determine the number of blocks in both the GPU cache
- and the swap CPU cache.
- """
- num_gpu_blocks, num_cpu_blocks = (
- self.model_executor.determine_num_available_blocks())
- if self.cache_config.num_gpu_blocks_override is not None:
- num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
- logger.info(f"Overriding {num_gpu_blocks=} with "
- f"{num_gpu_blocks_override=}")
- num_gpu_blocks = num_gpu_blocks_override
- self.cache_config.num_gpu_blocks = num_gpu_blocks
- self.cache_config.num_cpu_blocks = num_cpu_blocks
- self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
- @classmethod
- def _get_executor_cls(cls,
- engine_config: EngineConfig) -> Type[ExecutorBase]:
- distributed_executor_backend = (
- engine_config.parallel_config.distributed_executor_backend)
- # Initialize the cluster and specify the executor class.
- if isinstance(distributed_executor_backend, type):
- if not issubclass(distributed_executor_backend, ExecutorBase):
- raise TypeError(
- "distributed_executor_backend must be a subclass of "
- f"ExecutorBase. Got {distributed_executor_backend}.")
- if distributed_executor_backend.uses_ray: # type: ignore
- initialize_ray_cluster(engine_config.parallel_config)
- executor_class = distributed_executor_backend
- elif engine_config.device_config.device_type == "neuron":
- from aphrodite.executor.neuron_executor import NeuronExecutor
- executor_class = NeuronExecutor
- elif engine_config.device_config.device_type == "tpu":
- if distributed_executor_backend == "ray":
- initialize_ray_cluster(engine_config.parallel_config)
- from aphrodite.executor.ray_tpu_executor import RayTPUExecutor
- executor_class = RayTPUExecutor
- else:
- assert distributed_executor_backend is None
- from aphrodite.executor.tpu_executor import TPUExecutor
- executor_class = TPUExecutor
- elif engine_config.device_config.device_type == "cpu":
- from aphrodite.executor.cpu_executor import CPUExecutor
- executor_class = CPUExecutor
- elif engine_config.device_config.device_type == "openvino":
- from aphrodite.executor.openvino_executor import OpenVINOExecutor
- executor_class = OpenVINOExecutor
- elif engine_config.device_config.device_type == "xpu":
- if distributed_executor_backend == "ray":
- initialize_ray_cluster(engine_config.parallel_config)
- from aphrodite.executor.ray_xpu_executor import RayXPUExecutor
- executor_class = RayXPUExecutor
- elif distributed_executor_backend == "mp":
- logger.error(
- "Both start methods (spawn and fork) have issues "
- "on XPU if you use mp backend, Please try ray instead.")
- else:
- from aphrodite.executor.xpu_executor import XPUExecutor
- executor_class = XPUExecutor
- elif distributed_executor_backend == "ray":
- initialize_ray_cluster(engine_config.parallel_config)
- from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
- executor_class = RayGPUExecutor
- elif distributed_executor_backend == "mp":
- from aphrodite.executor.multiproc_gpu_executor import (
- MultiprocessingGPUExecutor)
- assert not envs.APHRODITE_USE_RAY_SPMD_WORKER, (
- "multiprocessing distributed executor backend does not "
- "support APHRODITE_USE_RAY_SPMD_WORKER=1")
- executor_class = MultiprocessingGPUExecutor
- else:
- from aphrodite.executor.gpu_executor import GPUExecutor
- executor_class = GPUExecutor
- return executor_class
- @classmethod
- def from_engine_args(
- cls,
- engine_args: EngineArgs,
- stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
- ) -> "AphroditeEngine":
- """Creates an Aphrodite engine from the engine arguments."""
- # Create the engine configs.
- engine_config = engine_args.create_engine_config()
- executor_class = cls._get_executor_cls(engine_config)
- # Create the LLM engine.
- engine = cls(
- **engine_config.to_dict(),
- executor_class=executor_class,
- log_stats=not engine_args.disable_log_stats,
- stat_loggers=stat_loggers,
- )
- return engine
- def __reduce__(self):
- # This is to ensure that the AphroditeEngine is not referenced in
- # the closure used to initialize Ray worker actors
- raise RuntimeError("AphroditeEngine should not be pickled!")
- def __del__(self):
- # Shutdown model executor when engine is garbage collected
- # Use getattr since __init__ can fail before the field is set
- if model_executor := getattr(self, "model_executor", None):
- model_executor.shutdown()
- def get_tokenizer_group(
- self,
- group_type: Type[_G] = BaseTokenizerGroup,
- ) -> _G:
- tokenizer_group = self.tokenizer
- if tokenizer_group is None:
- raise ValueError("Unable to get tokenizer because "
- "skip_tokenizer_init is True")
- if not isinstance(tokenizer_group, group_type):
- raise TypeError("Invalid type of tokenizer group. "
- f"Expected type: {group_type}, but "
- f"found type: {type(tokenizer_group)}")
- return tokenizer_group
- def get_tokenizer(
- self,
- lora_request: Optional[LoRARequest] = None,
- ) -> AnyTokenizer:
- return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
- def _init_tokenizer(self) -> BaseTokenizerGroup:
- return init_tokenizer_from_configs(
- model_config=self.model_config,
- scheduler_config=self.scheduler_config,
- parallel_config=self.parallel_config,
- enable_lora=bool(self.lora_config))
- def _verify_args(self) -> None:
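- """Cross-check the model, cache, LoRA and prompt-adapter configs
- against the parallel and scheduler configs."""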
- self.model_config.verify_with_parallel_config(self.parallel_config)
- self.cache_config.verify_with_parallel_config(self.parallel_config)
- if self.lora_config:
- self.lora_config.verify_with_model_config(self.model_config)
- self.lora_config.verify_with_scheduler_config(
- self.scheduler_config)
- if self.prompt_adapter_config:
- self.prompt_adapter_config.verify_with_model_config(
- self.model_config)
- def _add_processed_request(
- self,
- request_id: str,
- processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
- params: Union[SamplingParams, PoolingParams],
- arrival_time: float,
- lora_request: Optional[LoRARequest],
- prompt_adapter_request: Optional[PromptAdapterRequest],
- ) -> None:
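- """Create the Sequence(s) and SequenceGroup for a preprocessed request
- and add it to the scheduler with the fewest unfinished seq groups."""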
- self._validate_model_inputs(processed_inputs)
- # Create the sequences.
- block_size = self.cache_config.block_size
- seq_id = next(self.seq_counter)
- eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
- seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
- lora_request, prompt_adapter_request)
- encoder_seq = None
- if 'encoder_prompt_token_ids' in processed_inputs:
- encoder_seq = Sequence(seq_id,
- processed_inputs,
- block_size,
- eos_token_id,
- lora_request,
- prompt_adapter_request,
- from_decoder_prompt=False)
- # Create a SequenceGroup based on SamplingParams or PoolingParams
- if isinstance(params, SamplingParams):
- seq_group = self._create_sequence_group_with_sampling(
- request_id,
- seq,
- params,
- arrival_time=arrival_time,
- lora_request=lora_request,
- prompt_adapter_request=prompt_adapter_request,
- encoder_seq=encoder_seq)
- elif isinstance(params, PoolingParams):
- seq_group = self._create_sequence_group_with_pooling(
- request_id,
- seq,
- params,
- arrival_time=arrival_time,
- lora_request=lora_request,
- prompt_adapter_request=prompt_adapter_request,
- encoder_seq=encoder_seq)
- else:
- raise ValueError(
- "Either SamplingParams or PoolingParams must be provided.")
- # Add the sequence group to the scheduler with least unfinished seqs.
- costs = [
- scheduler.get_num_unfinished_seq_groups()
- for scheduler in self.scheduler
- ]
- min_cost_scheduler = self.scheduler[costs.index(min(costs))]
- min_cost_scheduler.add_seq_group(seq_group)
- def stop_remote_worker_execution_loop(self) -> None:
- self.model_executor.stop_remote_worker_execution_loop()
- def add_request(
- self,
- request_id: str,
- prompt: PromptType,
- params: Union[SamplingParams, PoolingParams],
- arrival_time: Optional[float] = None,
- lora_request: Optional[LoRARequest] = None,
- prompt_adapter_request: Optional[PromptAdapterRequest] = None,
- ) -> None:
- """Add a request to the engine's request pool.
- The request is added to the request pool and will be processed by the
- scheduler as `engine.step()` is called. The exact scheduling policy is
- determined by the scheduler.
- Args:
- request_id: The unique ID of the request.
- prompt: The prompt to the LLM. See
- :class:`~aphrodite.inputs.PromptType`
- for more details about the format of each input.
- params: Parameters for sampling or pooling. SamplingParams
- for text generation. PoolingParams for pooling.
- arrival_time: The arrival time of the request. If None, we use
- the current time.
- lora_request: The LoRA request to apply to this request, if any.
- prompt_adapter_request: The prompt adapter request to apply to
- this request, if any.
- Details:
- - Set arrival_time to the current time if it is None.
- - Set prompt_token_ids to the encoded prompt if it is None.
- - Create `best_of` number of :class:`~aphrodite.common.sequence.Sequence`
- objects.
- - Create a :class:`~aphrodite.common.sequence.SequenceGroup` object
- from the list of :class:`~aphrodite.common.sequence.Sequence`.
- - Add the :class:`~aphrodite.common.sequence.SequenceGroup` object to the
- scheduler.
- Example:
- >>> # initialize engine
- >>> engine = AphroditeEngine.from_engine_args(engine_args)
- >>> # set request arguments
- >>> example_prompt = "Who is the president of the United States?"
- >>> sampling_params = SamplingParams(temperature=0.0)
- >>> request_id = 0
- >>>
- >>> # add the request to the engine
- >>> engine.add_request(
- >>> str(request_id),
- >>> example_prompt,
- >>> sampling_params)
- >>> # continue the request processing
- >>> ...
- """
- if lora_request is not None and not self.lora_config:
- raise ValueError(f"Got lora_request {lora_request} but LoRA is "
- "not enabled!")
- if arrival_time is None:
- arrival_time = time.time()
- preprocessed_inputs = self.input_preprocessor.preprocess(
- prompt,
- request_id=request_id,
- lora_request=lora_request,
- prompt_adapter_request=prompt_adapter_request,
- )
- processed_inputs = self.input_processor(preprocessed_inputs)
- self._add_processed_request(
- request_id=request_id,
- processed_inputs=processed_inputs,
- params=params,
- arrival_time=arrival_time,
- lora_request=lora_request,
- prompt_adapter_request=prompt_adapter_request,
- )
- def _create_sequence_group_with_sampling(
- self,
- request_id: str,
- seq: Sequence,
- sampling_params: SamplingParams,
- arrival_time: float,
- lora_request: Optional[LoRARequest],
- prompt_adapter_request: Optional[PromptAdapterRequest] = None,
- encoder_seq: Optional[Sequence] = None,
- ) -> SequenceGroup:
- """Creates a SequenceGroup with SamplingParams."""
- max_logprobs = self.get_model_config().max_logprobs
- if (sampling_params.logprobs
- and sampling_params.logprobs > max_logprobs) or (
- sampling_params.prompt_logprobs
- and sampling_params.prompt_logprobs > max_logprobs):
- raise ValueError(f"Cannot request more than "
- f"{max_logprobs} logprobs.")
- # Defensive copy of SamplingParams, which are used by the sampler,
- # this doesn't deep-copy LogitsProcessor objects
- sampling_params = sampling_params.clone()
- sampling_params.update_from_generation_config(
- self.generation_config_fields, seq.eos_token_id)
- sampling_params._verify_with_scheduler_config(self.scheduler_config)
- # Create the sequence group.
- seq_group = SequenceGroup(
- request_id=request_id,
- seqs=[seq],
- arrival_time=arrival_time,
- sampling_params=sampling_params,
- lora_request=lora_request,
- prompt_adapter_request=prompt_adapter_request,
- encoder_seq=encoder_seq)
- return seq_group
- def _create_sequence_group_with_pooling(
- self,
- request_id: str,
- seq: Sequence,
- pooling_params: PoolingParams,
- arrival_time: float,
- lora_request: Optional[LoRARequest],
- prompt_adapter_request: Optional[PromptAdapterRequest],
- encoder_seq: Optional[Sequence] = None,
- ) -> SequenceGroup:
- """Creates a SequenceGroup with PoolingParams."""
- # Defensive copy of PoolingParams, which are used by the pooler
- pooling_params = pooling_params.clone()
- # Create the sequence group.
- seq_group = SequenceGroup(
- request_id=request_id,
- seqs=[seq],
- arrival_time=arrival_time,
- lora_request=lora_request,
- pooling_params=pooling_params,
- prompt_adapter_request=prompt_adapter_request,
- encoder_seq=encoder_seq)
- return seq_group
- def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
- """Aborts a request(s) with the given ID.
- Args:
- request_id: The ID(s) of the request to abort.
- Details:
- - Refer to the
- :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
- from class :class:`~aphrodite.processing.scheduler.Scheduler`.
- Example:
- >>> # initialize engine and add a request with request_id
- >>> request_id = str(0)
- >>> # abort the request
- >>> engine.abort_request(request_id)
- """
- for scheduler in self.scheduler:
- scheduler.abort_seq_group(request_id)
- def get_model_config(self) -> ModelConfig:
- """Gets the model configuration."""
- return self.model_config
- def get_parallel_config(self) -> ParallelConfig:
- """Gets the parallel configuration."""
- return self.parallel_config
- def get_decoding_config(self) -> DecodingConfig:
- """Gets the decoding configuration."""
- return self.decoding_config
- def get_scheduler_config(self) -> SchedulerConfig:
- """Gets the scheduler configuration."""
- return self.scheduler_config
- def get_lora_config(self) -> LoRAConfig:
- """Gets the LoRA configuration."""
- return self.lora_config
- def get_num_unfinished_requests(self) -> int:
- """Gets the number of unfinished requests."""
- return sum(scheduler.get_num_unfinished_seq_groups()
- for scheduler in self.scheduler)
- def has_unfinished_requests(self) -> bool:
- """Returns True if there are unfinished requests."""
- return any(scheduler.has_unfinished_seqs()
- for scheduler in self.scheduler)
- def has_unfinished_requests_for_virtual_engine(
- self, virtual_engine: int) -> bool:
- """
- Returns True if there are unfinished requests for the virtual engine.
- """
- return self.scheduler[virtual_engine].has_unfinished_seqs()
- @staticmethod
- def _process_sequence_group_outputs(
- seq_group: SequenceGroup,
- outputs: List[EmbeddingSequenceGroupOutput],
- ) -> None:
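- """Attach the embedding output to the sequence group and mark all of its
- sequences as finished."""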
- seq_group.embeddings = outputs[0].embeddings
- for seq in seq_group.get_seqs():
- seq.status = SequenceStatus.FINISHED_STOPPED
- return
- def _process_model_outputs(self,
- ctx: SchedulerContext,
- request_id: Optional[str] = None) -> None:
- """Apply the model output to the sequences in the scheduled seq groups
- and return responses.
- Args:
- ctx: The virtual engine context to work on.
- request_id: If provided, only this request is processed.
- """
- now = time.time()
- if len(ctx.output_queue) == 0:
- return None
- # Get pending async postprocessor
- if request_id:
- # When we process only one request, no pop is required
- # (since later we will process all of the rest)
- (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
- is_last_step, skip) = ctx.output_queue[0]
- else:
- (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
- is_last_step, skip) = ctx.output_queue.popleft()
- # Sanity check
- assert len(seq_group_metadata_list) == len(
- scheduler_outputs.scheduled_seq_groups)
- # Organize outputs by [step][sequence group] instead of
- # [sequence group][step].
- if len(outputs) > 1:
- outputs_by_sequence_group = create_output_by_sequence_group(
- outputs, num_seq_groups=len(seq_group_metadata_list))
- else:
- outputs_by_sequence_group = outputs
- # Determine the requests we need to operate on
- if request_id:
- indices = []
- for i, seq_group_meta in enumerate(seq_group_metadata_list):
- if seq_group_meta.request_id == request_id:
- assert i not in skip # Cannot be called twice
- indices.append(i)
- break
- # If the request_id was not found, then it means that
- # this is a new request that has no pending async
- # postprocessor
- if not indices:
- return
- else:
- indices = range(len(seq_group_metadata_list)) # type: ignore
- finished_before: List[int] = []
- finished_now: List[int] = []
- for i in indices:
- if i in skip:
- continue
- seq_group_meta = seq_group_metadata_list[i]
- scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
- seq_group = scheduled_seq_group.seq_group
- if seq_group.is_finished():
- finished_before.append(i)
- continue
- if len(outputs) > 1:
- output = outputs_by_sequence_group[i]
- else:
- output = [outputs_by_sequence_group[0][i]]
- if not is_async:
- seq_group.update_num_computed_tokens(
- scheduled_seq_group.token_chunk_size)
- if self.model_config.embedding_mode:
- self._process_sequence_group_outputs(seq_group, output)
- else:
- self.output_processor.process_prompt_logprob(seq_group, output)
- if seq_group_meta.do_sample:
- self.output_processor.process_outputs(
- seq_group, output, is_async)
- if seq_group.is_finished():
- finished_now.append(i)
- # Generate outputs for the requests that finished this iteration
- for i in finished_now:
- scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
- seq_group = scheduled_seq_group.seq_group
- seq_group.maybe_set_first_token_time(now)
- request_output = RequestOutputFactory.create(seq_group)
- if request_output:
- ctx.request_outputs.append(request_output)
- # When we process a single request, we skip it for the next time,
- # and invoke the request output callback (if there was final output)
- if request_id:
- assert len(indices) == 1
- skip.append(indices[0])
- if (finished_now
- and self.process_request_outputs_callback is not None):
- self.process_request_outputs_callback(ctx.request_outputs)
- ctx.request_outputs.clear()
- return
- # Free currently finished requests
- if finished_now:
- for scheduler in self.scheduler:
- scheduler.free_finished_seq_groups()
- # For multi-step, do not create outputs each iteration
- if not is_last_step:
- # Immediately process request outputs here (if callback is given)
- if (finished_now
- and self.process_request_outputs_callback is not None):
- self.process_request_outputs_callback(ctx.request_outputs)
- ctx.request_outputs.clear()
- return
- # Create the outputs
- for i in indices:
- if i in skip or i in finished_before or i in finished_now:
- continue # Avoids double processing
- scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
- seq_group = scheduled_seq_group.seq_group
- seq_group.maybe_set_first_token_time(now)
- request_output = RequestOutputFactory.create(seq_group)
- if request_output:
- ctx.request_outputs.append(request_output)
- for seq_group in scheduler_outputs.ignored_seq_groups:
- params = seq_group.sampling_params
- if params is not None and params.output_kind == (
- RequestOutputKind.DELTA) and not seq_group.is_finished():
- continue
- request_output = RequestOutputFactory.create(seq_group)
- if request_output:
- ctx.request_outputs.append(request_output)
- # Immediately process request outputs here (if callback is given)
- if (ctx.request_outputs
- and self.process_request_outputs_callback is not None):
- self.process_request_outputs_callback(ctx.request_outputs)
- ctx.request_outputs.clear()
- # For async case, we need to record the stats here.
- # For non-async case, the stats are done in the
- # AphroditeEngine/AsyncAphrodite directly
- if is_async:
- # Log stats.
- self.do_log_stats(scheduler_outputs, outputs, finished_before,
- skip)
- return None
- def _advance_to_next_step(
- self, output: List[SamplerOutput],
- seq_group_metadata_list: List[SequenceGroupMetadata],
- scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
- """Given model output from a single run, append the tokens to the
- sequences. This is normally done inside output processor, but it is
- required if the worker is to perform async forward pass to next step.
- """
- for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
- zip(seq_group_metadata_list, output, scheduled_seq_groups):
- seq_group = scheduled_seq_group.seq_group
- if seq_group.is_finished():
- continue
- seq_group.update_num_computed_tokens(
- seq_group_metadata.token_chunk_size)
- if seq_group_metadata.do_sample:
- assert len(sequence_group_outputs.samples) == 1, (
- "Async output processor expects a single sample"
- " (i.e sampling_params.n == 1 and no "
- "sampling_params.best_of > 1)")
- sample = sequence_group_outputs.samples[0]
- assert len(seq_group.seqs) == 1
- seq = seq_group.seqs[0]
- seq.append_token_id(sample.output_token, sample.logprobs)
- def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
- """Performs one decoding iteration and returns newly generated results.
- .. figure:: https://i.imgur.com/sv2HssD.png
- :alt: Overview of the step function
- :align: center
- Overview of the step function.
- Details:
- - Step 1: Schedules the sequences to be executed in the next
- iteration and the token blocks to be swapped in/out/copy.
- - Depending on the scheduling policy,
- sequences may be `preempted/reordered`.
- - A Sequence Group (SG) refers to a group of sequences
- that are generated from the same prompt.
- - Step 2: Calls the distributed executor to execute the model.
- - Step 3: Processes the model output. This mainly includes:
- - Decodes the relevant outputs.
- - Updates the scheduled sequence groups with model outputs
- based on their `sampling parameters` (`use_beam_search` or not).
- - Frees the finished sequence groups.
- - Finally, it creates and returns the newly generated results.
- Example:
- >>> # Please see the example/ folder for more detailed examples.
- >>>
- >>> # initialize engine and request arguments
- >>> engine = AphroditeEngine.from_engine_args(engine_args)
- >>> example_inputs = [(0, "What is LLM?",
- >>> SamplingParams(temperature=0.0))]
- >>>
- >>> # Start the engine with an event loop
- >>> while True:
- >>> if example_inputs:
- >>> req_id, prompt, sampling_params = example_inputs.pop(0)
- >>> engine.add_request(str(req_id),prompt,sampling_params)
- >>>
- >>> # continue the request processing
- >>> request_outputs = engine.step()
- >>> for request_output in request_outputs:
- >>> if request_output.finished:
- >>> # return or show the request output
- >>>
- >>> if not (engine.has_unfinished_requests() or example_inputs):
- >>> break
- """
- if self.parallel_config.pipeline_parallel_size > 1:
- raise NotImplementedError(
- "Pipeline parallelism is only supported through AsyncAphrodite "
- "as performance will be severely degraded otherwise.")
- # For llm_engine, there is no pipeline parallel support, so the
- # virtual engine index used is always 0.
- virtual_engine = 0
- # These are cached outputs from previous iterations. None if on first
- # iteration
- cached_outputs = self.cached_scheduler_outputs[virtual_engine]
- seq_group_metadata_list = cached_outputs.seq_group_metadata_list
- scheduler_outputs = cached_outputs.scheduler_outputs
- allow_async_output_proc = cached_outputs.allow_async_output_proc
- ctx = self.scheduler_contexts[virtual_engine]
- # Clear outputs for each new scheduler iteration
- ctx.request_outputs.clear()
- # Skip the scheduler if there are any remaining steps in the seq groups.
- # This ensures that the scheduler is only called again when the current
- # batch has completed.
- if not self._has_remaining_steps(seq_group_metadata_list):
- # Schedule iteration
- (seq_group_metadata_list, scheduler_outputs,
- allow_async_output_proc
- ) = self.scheduler[virtual_engine].schedule()
- ctx.seq_group_metadata_list = seq_group_metadata_list
- ctx.scheduler_outputs = scheduler_outputs
- # Maybe switch from async mode to sync mode
- if not allow_async_output_proc and len(ctx.output_queue) > 0:
- self._process_model_outputs(ctx=ctx)
- if (self.scheduler_config.is_multi_step
- and scheduler_outputs.num_lookahead_slots > 0):
- # cache the scheduler outputs for the next iteration if we have
- # lookahead slots
- self._cache_scheduler_outputs_for_multi_step(
- virtual_engine, seq_group_metadata_list, scheduler_outputs,
- allow_async_output_proc)
- assert seq_group_metadata_list is not None
- assert scheduler_outputs is not None
- if not scheduler_outputs.is_empty():
- finished_requests_ids = self.scheduler[
- virtual_engine].get_and_reset_finished_requests_ids()
- # Check if we have a cached last_output from the previous iteration.
- # For supporting PP this is probably the best way to pass the
- # sampled_token_ids, as a separate broadcast over all the PP stages
- # will cause one virtual engine's microbatch to block the pipeline.
- last_sampled_token_ids = \
- self._get_last_sampled_token_ids(virtual_engine)
- execute_model_req = ExecuteModelRequest(
- seq_group_metadata_list=seq_group_metadata_list,
- blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
- blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
- blocks_to_copy=scheduler_outputs.blocks_to_copy,
- num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
- running_queue_size=scheduler_outputs.running_queue_size,
- finished_requests_ids=finished_requests_ids,
- # We use ExecuteModelRequest to pass the last sampled_token_ids
- # to each of the non-last PP stages for in-place prepare_input.
- last_sampled_token_ids=last_sampled_token_ids)
- if allow_async_output_proc:
- execute_model_req.async_callback = self.async_callbacks[
- virtual_engine]
- outputs = self.model_executor.execute_model(
- execute_model_req=execute_model_req)
- # We need to do this here so that last step's sampled_token_ids can
- # be passed to the next iteration for PP.
- if self.scheduler_config.is_multi_step:
- self._update_cached_scheduler_output(virtual_engine, outputs)
- else:
- # Nothing scheduled => If there is pending async postprocessor,
- # then finish it here.
- if len(ctx.output_queue) > 0:
- self._process_model_outputs(ctx=ctx)
- # No outputs in this case
- outputs = []
- # Finish the current step for all the sequence groups.
- if self.scheduler_config.is_multi_step:
- for seq_group in seq_group_metadata_list:
- seq_group.finish_step()
- if not self._has_remaining_steps(seq_group_metadata_list):
- # clear the cache if we have finished all the steps.
- if self.scheduler_config.is_multi_step:
- self.cached_scheduler_outputs[0] = SchedulerOutputState()
- # Add results to the output_queue
- ctx.append_output(outputs=outputs,
- seq_group_metadata_list=seq_group_metadata_list,
- scheduler_outputs=scheduler_outputs,
- is_async=allow_async_output_proc,
- is_last_step=True)
- if outputs and allow_async_output_proc:
- assert len(outputs) == 1, (
- "Async postprocessor expects only a single output set")
- self._advance_to_next_step(
- outputs[0], seq_group_metadata_list,
- scheduler_outputs.scheduled_seq_groups)
- # Check if need to run the usual non-async path
- if not allow_async_output_proc:
- self._process_model_outputs(ctx=ctx)
- # Log stats.
- self.do_log_stats(scheduler_outputs, outputs)
- else:
- # Multi-step case
- return ctx.request_outputs
- if not self.has_unfinished_requests():
- # Drain async postprocessor (if exists)
- if len(ctx.output_queue) > 0:
- self._process_model_outputs(ctx=ctx)
- assert len(ctx.output_queue) == 0
- # Stop the execute model loop in parallel workers until there are
- # more requests to process. This avoids waiting indefinitely in
- # torch.distributed ops which may otherwise timeout, and unblocks
- # the RPC thread in the workers so that they can process any other
- # queued control plane messages, such as add/remove lora adapters.
- logger.debug("Stopping remote worker execution loop.")
- self.model_executor.stop_remote_worker_execution_loop()
- return ctx.request_outputs
- def _has_remaining_steps(
- self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
- ) -> bool:
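- """Return True if multi-step scheduling still has steps left for the
- currently cached batch."""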
- if (not self.scheduler_config.is_multi_step
- or not seq_group_metadata_list):
- return False
- # TODO: this is a sanity check for now to make sure that all the
- # seqs are on the same steps. Eventually we will want to do some sort of
- # dynamic scheduling when doing multi-step decoding.
- ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
- if any([
- seq_group.state.remaining_steps != ref_remaining_steps
- for seq_group in seq_group_metadata_list[1:]
- ]):
- raise AssertionError(("All running sequence groups should "
- "have the same remaining steps."))
- return ref_remaining_steps > 0
- def _cache_scheduler_outputs_for_multi_step(
- self, virtual_engine: int,
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
- scheduler_outputs: SchedulerOutputs,
- allow_async_output_proc: bool) -> None:
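- """Cache this iteration's scheduler outputs so later multi-step
- iterations can reuse them without rescheduling."""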
- co = self.cached_scheduler_outputs[virtual_engine]
- co.seq_group_metadata_list = seq_group_metadata_list
- co.scheduler_outputs = scheduler_outputs
- co.allow_async_output_proc = allow_async_output_proc
- co.last_output = None
- def _update_cached_scheduler_output(
- self, virtual_engine: int,
- output: List[Optional[SamplerOutput]]) -> None:
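- """With pipeline parallelism, cache the last sampler output so its CPU
- sampled token IDs can be passed to the next iteration."""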
- if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
- and output[0] is not None):
- last_output = output[-1]
- assert last_output is not None
- assert last_output.sampled_token_ids_cpu is not None
- assert last_output.sampled_token_ids is None
- assert last_output.sampled_token_probs is None
- self.cached_scheduler_outputs[
- virtual_engine].last_output = last_output
- def _get_last_sampled_token_ids(
- self, virtual_engine: int) -> Optional[torch.Tensor]:
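- """Return the cached CPU sampled token IDs for multi-step pipeline
- parallelism, or None when not applicable."""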
- cached_last_output = self.cached_scheduler_outputs[
- virtual_engine].last_output
- if (self.scheduler_config.is_multi_step
- and self.parallel_config.pipeline_parallel_size > 1
- and cached_last_output is not None
- and cached_last_output.sampled_token_ids_cpu is not None):
- return cached_last_output.sampled_token_ids_cpu
- return None
- def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
- if not self.log_stats:
- raise RuntimeError(
- "Stat logging is disabled. Set `disable_log_stats=False` "
- "argument to enable.")
- if logger_name in self.stat_loggers:
- raise KeyError(f"Logger with name {logger_name} already exists.")
- self.stat_loggers[logger_name] = logger
- def remove_logger(self, logger_name: str) -> None:
- if not self.log_stats:
- raise RuntimeError(
- "Stat logging is disabled. Set `disable_log_stats=False` "
- "argument to enable.")
- if logger_name not in self.stat_loggers:
- raise KeyError(f"Logger with name {logger_name} does not exist.")
- del self.stat_loggers[logger_name]
- def do_log_stats(self,
- scheduler_outputs: Optional[SchedulerOutputs] = None,
- model_output: Optional[List[SamplerOutput]] = None,
- finished_before: Optional[List[int]] = None,
- skip: Optional[List[int]] = None) -> None:
- """Forced log when no requests active."""
- if self.log_stats:
- stats = self._get_stats(scheduler_outputs, model_output,
- finished_before, skip)
- for loggers in self.stat_loggers.values():
- loggers.log(stats)
- def _get_stats(self,
- scheduler_outputs: Optional[SchedulerOutputs],
- model_output: Optional[List[SamplerOutput]] = None,
- finished_before: Optional[List[int]] = None,
- skip: Optional[List[int]] = None) -> Stats:
- """Get Stats to be Logged to Prometheus.
- Args:
- scheduler_outputs: Optional, used to populate metrics related to
- the scheduled batch,
- model_output: Optional, used to emit speculative decoding metrics
- which are created by the workers.
- finished_before: Optional, indices of sequences that were finished
- before. These sequences will be ignored.
- skip: Optional, indices of sequences that were preempted. These
- sequences will be ignored.
- """
- now = time.time()
- # System State
- # Scheduler State
- num_running_sys = sum(
- len(scheduler.running) for scheduler in self.scheduler)
- num_swapped_sys = sum(
- len(scheduler.swapped) for scheduler in self.scheduler)
- num_waiting_sys = sum(
- len(scheduler.waiting) for scheduler in self.scheduler)
- # KV Cache Usage in %
- num_total_gpu = self.cache_config.num_gpu_blocks
- gpu_cache_usage_sys = 0.
- if num_total_gpu: # Guard against both None and 0
- num_free_gpu = sum(
- scheduler.block_manager.get_num_free_gpu_blocks()
- for scheduler in self.scheduler)
- gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
- num_total_cpu = self.cache_config.num_cpu_blocks
- cpu_cache_usage_sys = 0.
- if num_total_cpu: # Guard against both None and 0
- num_free_cpu = sum(
- scheduler.block_manager.get_num_free_cpu_blocks()
- for scheduler in self.scheduler)
- cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
- # Prefix Cache Hit Rate. Note that we always use
- # the cache hit rate of the first virtual engine.
- cpu_prefix_cache_hit_rate = self.scheduler[
- 0].get_prefix_cache_hit_rate(Device.CPU)
- gpu_prefix_cache_hit_rate = self.scheduler[
- 0].get_prefix_cache_hit_rate(Device.GPU)
- # Iteration stats
- num_prompt_tokens_iter = 0
- num_generation_tokens_iter = 0
- time_to_first_tokens_iter: List[float] = []
- time_per_output_tokens_iter: List[float] = []
- num_preemption_iter = (0 if scheduler_outputs is None else
- scheduler_outputs.preempted)
- # Request stats
- # Latency
- time_e2e_requests: List[float] = []
- # Metadata
- num_prompt_tokens_requests: List[int] = []
- num_generation_tokens_requests: List[int] = []
- best_of_requests: List[int] = []
- n_requests: List[int] = []
- finished_reason_requests: List[str] = []
- # NOTE: This loop assumes prefill seq_groups are before
- # decode seq_groups in scheduled_seq_groups.
- if scheduler_outputs is not None:
- # For async postprocessor, already finished sequences must not be
- # counted again (to avoid double counting)
- actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
- num_generation_tokens_from_prefill_groups = 0.
- # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
- # the len of scheduler_outputs.scheduled_seq_groups is !=
- # scheduler_outputs.num_prefill_groups, this means that
- # chunked prefills have been detected.
- for idx, scheduled_seq_group in enumerate(
- scheduler_outputs.scheduled_seq_groups):
- # Skip double logging when using async output proc
- if finished_before and idx in finished_before:
- actual_num_batched_tokens -= 1
- continue
- # Currently, skip == preempted sequences, so we need to skip
- # their log stats
- if skip and idx in skip:
- continue
- group_was_prefill = idx < scheduler_outputs.num_prefill_groups
- seq_group = scheduled_seq_group.seq_group
- # NOTE: a seq_group that completed all of its prefill tokens
- # in the last iteration will have seq_group.is_prefill() = False
- # with group_was_prefill = True
- if group_was_prefill:
- # Number of prompt tokens.
- num_prompt_tokens_iter += (
- scheduled_seq_group.token_chunk_size)
- # If the seq_group just finished the prefill state
- # get TTFT.
- if not seq_group.is_prefill():
- latency = seq_group.get_last_latency(now)
- time_to_first_tokens_iter.append(latency)
- # One generation token per finished prefill.
- num_generation_tokens_from_prefill_groups += (
- seq_group.num_seqs())
- else:
- # TPOTs.
- latency = seq_group.get_last_latency(now)
- time_per_output_tokens_iter.append(latency)
- # Because of chunked prefill, we can have a single sequence
- # group that does multiple prompt_runs. To prevent logging
- # the same metadata more than once per request, we standardize
- # on logging request level information for finished requests,
- # which can only happen once.
- if seq_group.is_finished():
- # Latency timings
- time_e2e_requests.append(now -
- seq_group.metrics.arrival_time)
- # Metadata
- num_prompt_tokens_requests.append(
- len(seq_group.prompt_token_ids))
- num_generation_tokens_requests.extend([
- seq.get_output_len()
- for seq in seq_group.get_finished_seqs()
- ])
- if seq_group.sampling_params is not None:
- best_of_requests.append(
- seq_group.sampling_params.best_of)
- n_requests.append(seq_group.sampling_params.n)
- finished_reason_requests.extend([
- SequenceStatus.get_finished_reason(seq.status)
- for seq in seq_group.get_finished_seqs()
- ])
- # Number of generation tokens.
- # num_batched_tokens equals the number of prompt_tokens plus the
- # number of decode_tokens in a single iteration. So,
- # num_generation_tokens = num_batched_tokens - num_prompt_tokens
- # + num_generation_tokens_from_prefill_groups (since we generate
- # one token on prefills on iters where the prefill finishes).
- num_generation_tokens_iter = (
- actual_num_batched_tokens - num_prompt_tokens_iter +
- num_generation_tokens_from_prefill_groups)
- # Spec decode, if enabled, emits specialized metrics from the worker in
- # sampler output.
- if model_output and (model_output[0].spec_decode_worker_metrics
- is not None):
- spec_decode_metrics = model_output[0].spec_decode_worker_metrics
- else:
- spec_decode_metrics = None
- return Stats(
- now=now,
- # System stats
- # Scheduler State
- num_running_sys=num_running_sys,
- num_swapped_sys=num_swapped_sys,
- num_waiting_sys=num_waiting_sys,
- # KV Cache Usage in %
- gpu_cache_usage_sys=gpu_cache_usage_sys,
- cpu_cache_usage_sys=cpu_cache_usage_sys,
- # Prefix Cache Hit Rate
- cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
- gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
- # Iteration stats
- num_prompt_tokens_iter=num_prompt_tokens_iter,
- num_generation_tokens_iter=num_generation_tokens_iter,
- time_to_first_tokens_iter=time_to_first_tokens_iter,
- time_per_output_tokens_iter=time_per_output_tokens_iter,
- spec_decode_metrics=spec_decode_metrics,
- num_preemption_iter=num_preemption_iter,
- # Request stats
- # Latency
- time_e2e_requests=time_e2e_requests,
- # Metadata
- num_prompt_tokens_requests=num_prompt_tokens_requests,
- num_generation_tokens_requests=num_generation_tokens_requests,
- best_of_requests=best_of_requests,
- n_requests=n_requests,
- finished_reason_requests=finished_reason_requests,
- )
- def add_lora(self, lora_request: LoRARequest) -> bool:
- return self.model_executor.add_lora(lora_request)
- def remove_lora(self, lora_id: int) -> bool:
- return self.model_executor.remove_lora(lora_id)
- def list_loras(self) -> Set[int]:
- return self.model_executor.list_loras()
- def pin_lora(self, lora_id: int) -> bool:
- return self.model_executor.pin_lora(lora_id)
- def add_prompt_adapter(
- self, prompt_adapter_request: PromptAdapterRequest) -> bool:
- return self.model_executor.add_prompt_adapter(prompt_adapter_request)
- def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
- return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
- def list_prompt_adapters(self) -> List[int]:
- return self.model_executor.list_prompt_adapters()
- def check_health(self) -> None:
- if self.tokenizer:
- self.tokenizer.check_health()
- self.model_executor.check_health()
- def shutdown(self) -> None:
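- """Stop remote workers, drop tokenizer and scheduler state, and shut
- down the model executor."""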
- self.model_executor.stop_remote_worker_execution_loop()
- if hasattr(self, 'tokenizer') and self.tokenizer is not None:
- self.tokenizer = None
- if hasattr(self, 'scheduler'):
- self.scheduler.clear()
- if hasattr(self, 'cached_scheduler_outputs'):
- self.cached_scheduler_outputs.clear()
- if hasattr(self, 'scheduler_contexts'):
- self.scheduler_contexts.clear()
- if hasattr(self, 'stat_loggers'):
- self.stat_loggers.clear()
- if hasattr(self, 'model_executor'):
- self.model_executor.shutdown()
-
- def is_encoder_decoder_model(self):
- return self.input_preprocessor.is_encoder_decoder_model()
- def is_embedding_model(self):
- return self.model_config.is_embedding_model
- def _validate_model_inputs(self, inputs: Union[LLMInputs,
- EncoderDecoderLLMInputs]):
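- """Reject empty prompts and, for multimodal models, prompts longer than
- the model context length."""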
- if self.is_encoder_decoder_model():
- prompt_ids = inputs.get("encoder_prompt_token_ids")
- else:
- prompt_ids = inputs.get("prompt_token_ids")
- if prompt_ids is None or len(prompt_ids) == 0:
- raise ValueError("Prompt cannot be empty")
- if self.model_config.is_multimodal_model:
- max_prompt_len = self.model_config.max_model_len
- if len(prompt_ids) > max_prompt_len:
- raise ValueError(
- f"The prompt (total length {len(prompt_ids)}) is too long "
- f"to fit into the model (context length {max_prompt_len}). "
- "Make sure that `max_model_len` is no smaller than the "
- "number of text tokens plus multimodal tokens. For image "
- "inputs, the number of image tokens depends on the number "
- "of images, and possibly their aspect ratios as well.")
- # TODO: Find out how many placeholder tokens are there so we can
- # check that chunked prefill does not truncate them
- # max_batch_len = self.scheduler_config.max_num_batched_tokens
- setup_logger()