12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471 |
- import time
- from contextlib import contextmanager
- from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
- from typing import Sequence as GenericSequence
- from typing import Tuple, Type, TypeVar, Union
- from loguru import logger
- from transformers import PreTrainedTokenizer
- from typing_extensions import assert_never
- import aphrodite.common.envs as envs
- from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
- EngineConfig, LoadConfig, LoRAConfig,
- ModelConfig, ParallelConfig,
- PromptAdapterConfig, SchedulerConfig,
- SpeculativeConfig)
- from aphrodite.common.logger import setup_logger
- from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
- RequestOutputFactory)
- from aphrodite.common.pooling_params import PoolingParams
- from aphrodite.common.sampling_params import SamplingParams
- from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
- ExecuteModelRequest, PoolerOutput,
- SamplerOutput, Sequence, SequenceGroup,
- SequenceGroupMetadata, SequenceStatus)
- from aphrodite.common.utils import Counter, Device
- from aphrodite.engine.args_tools import EngineArgs
- from aphrodite.engine.metrics_types import StatLoggerBase, Stats
- from aphrodite.engine.output_processor.interfaces import (
- SequenceGroupOutputProcessor)
- from aphrodite.engine.output_processor.stop_checker import StopChecker
- from aphrodite.engine.output_processor.util import (
- create_output_by_sequence_group)
- from aphrodite.executor.executor_base import ExecutorBase
- from aphrodite.executor.ray_utils import initialize_ray_cluster
- from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
- InputRegistry, LLMInputs, PromptInputs,
- SingletonPromptInputs)
- from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
- from aphrodite.lora.request import LoRARequest
- from aphrodite.multimodal import MultiModalDataDict
- from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
- SchedulerOutputs)
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
- from aphrodite.transformers_utils.config import try_get_generation_config
- from aphrodite.transformers_utils.detokenizer import Detokenizer
- from aphrodite.transformers_utils.tokenizer_group import (
- BaseTokenizerGroup, init_tokenizer_from_configs)
- from aphrodite.version import __version__ as APHRODITE_VERSION
# Interval, in seconds, passed as `local_interval` to the default stat
# loggers created in AphroditeEngine.__init__.
_LOCAL_LOGGING_INTERVAL_SEC = 5
# Value of the APHRODITE_USE_RAY_SPMD_WORKER setting, read once from
# `aphrodite.common.envs` when this module is imported. The "mp" executor
# backend asserts that it is falsy.
APHRODITE_USE_RAY_SPMD_WORKER = envs.APHRODITE_USE_RAY_SPMD_WORKER
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a plain dictionary.

    The config is looked up with `try_get_generation_config`, using the
    model identifier, the `trust_remote_code` flag and the revision stored
    on `model_config`.

    Args:
        model_config: The configuration related to the LLM model.

    Returns:
        The result of `to_diff_dict()` on the generation config, or an
        empty dict when the lookup returned None.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    if generation_config is not None:
        return generation_config.to_diff_dict()
    return {}
# Type variable constrained to the two request-output classes; it types the
# output validators on AphroditeEngine.
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
# (prompt text, prompt token ids, multi-modal data) extracted from a single
# prompt; the text may be None.
PromptComponents = Tuple[Optional[str], List[int],
                         Optional[MultiModalDataDict]]
# Same layout as PromptComponents, except that the token ids may also be
# None (used when no decoder prompt was supplied).
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional[MultiModalDataDict]]
- class AphroditeEngine:
- """An LLM engine that receives requests and generates texts.
- This is the main class for the Aphrodite engine. It receives requests
- from clients and generates texts from the LLM. It includes a tokenizer, a
- language model (possibly distributed across multiple GPUs), and GPU memory
- space allocated for intermediate states (aka KV cache). This class utilizes
- iteration-level scheduling and efficient memory management to maximize the
- serving throughput.
- The `LLM` class wraps this class for offline batched inference and the
- `AsyncAphrodite` class wraps this class for online serving.
- NOTE: The config arguments are derived from the `EngineArgs` class. For the
- comprehensive list of arguments, see `EngineArgs`.
- Args:
- model_config: The configuration related to the LLM model.
- cache_config: The configuration related to the KV cache memory
- management.
- parallel_config: The configuration related to distributed execution.
- scheduler_config: The configuration related to the request scheduler.
- device_config: The configuration related to the device.
- lora_config (Optional): The configuration related to serving multi-LoRA.
- speculative_config (Optional): The configuration related to speculative
- decoding.
- executor_class: The model executor class for managing distributed
- execution.
- prompt_adapter_config (Optional): The configuration related to serving
- prompt adapters.
- log_stats: Whether to log statistics.
- """
- DO_VALIDATE_OUTPUT: ClassVar[bool] = False
- """A flag to toggle whether to validate the type of request output."""
- @classmethod
- @contextmanager
- def enable_output_validation(cls):
- cls.DO_VALIDATE_OUTPUT = True
- yield
- cls.DO_VALIDATE_OUTPUT = False
- @classmethod
- def validate_output(
- cls,
- output: object,
- output_type: Type[_O],
- ) -> _O:
- do_validate = cls.DO_VALIDATE_OUTPUT
- if ((TYPE_CHECKING or do_validate)
- and not isinstance(output, output_type)):
- raise TypeError(f"Expected output of type {output_type}, "
- f"but found type {type(output)}")
- return output
- @classmethod
- def validate_outputs(
- cls,
- outputs: GenericSequence[object],
- output_type: Type[_O],
- ) -> List[_O]:
- do_validate = cls.DO_VALIDATE_OUTPUT
- outputs_: List[_O]
- if TYPE_CHECKING or do_validate:
- outputs_ = []
- for output in outputs:
- if not isinstance(output, output_type):
- raise TypeError(f"Expected output of type {output_type}, "
- f"but found type {type(output)}")
- outputs_.append(output)
- else:
- outputs_ = outputs
- return outputs_
- tokenizer: Optional[BaseTokenizerGroup]
def __init__(
    self,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    speculative_config: Optional[SpeculativeConfig],
    decoding_config: Optional[DecodingConfig],
    prompt_adapter_config: Optional[PromptAdapterConfig],
    executor_class: Type[ExecutorBase],
    log_stats: bool,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    input_registry: InputRegistry = INPUT_REGISTRY,
) -> None:
    """Set up the engine from already-built configuration objects.

    In order, this logs a summary of the configuration, loads the general
    plugins, stores the configs on the instance, creates the tokenizer
    and detokenizer (unless `skip_tokenizer_init` is set), the input
    processor and the model executor, initializes the KV caches (skipped
    in embedding mode), creates `pipeline_parallel_size` schedulers, sets
    up the stat loggers when `log_stats` is true, and finally creates the
    sequence output processor.

    Args:
        load_config: Model loading configuration; its `load_format` is
            shown in the startup summary and it is passed to the
            executor.
        decoding_config: Optional decoding configuration; a default
            `DecodingConfig()` is stored when it is falsy.
        stat_loggers: Optional mapping of stat loggers to use instead of
            the default "logging" and "prometheus" loggers. Only
            consulted when `log_stats` is true.
        input_registry: Registry used to create the input processor for
            `model_config`.

    The remaining arguments are described in the class docstring.
    """
    # The commit_id module is optional; remember whether it could be
    # imported so the banner can include the commit only when available.
    try:
        import aphrodite.commit_id
        commit_id = True
    except ImportError:
        commit_id = False

    # Human-readable summary of the configuration, logged below.
    config_dict = {
        "Model": model_config.model,
        "Speculative Config": speculative_config,
        "DataType": model_config.dtype,
        "Model Load Format": load_config.load_format,
        "Tensor Parallel Size": parallel_config.tensor_parallel_size,
        "Pipeline Parallel Size": parallel_config.pipeline_parallel_size,
        "Disable Custom All-Reduce":
        parallel_config.disable_custom_all_reduce,
        "Quantization Format": model_config.quantization,
        "Context Length": model_config.max_model_len,
        "Enforce Eager Mode": model_config.enforce_eager,
        "Prefix Caching": cache_config.enable_prefix_caching,
        "KV Cache DataType": cache_config.cache_dtype,
        "Device": device_config.device,
        "Rope Scaling": model_config.rope_scaling,
        "Guided Decoding Backend": decoding_config
    }
    logger.info("-" * 85)
    if not commit_id:
        logger.info(
            f"Initializing Aphrodite Engine (v{APHRODITE_VERSION}) "
            "with the following config:")
    else:
        logger.info(f"Initializing Aphrodite Engine (v{APHRODITE_VERSION} "
                    f"commit {aphrodite.__short_commit__}) with the "
                    "following config:")
    # Skip entries that are None, and the load format / KV cache dtype
    # entries when they are "auto".
    for key, value in config_dict.items():
        if value is not None and not ((key == "Model Load Format" or key ==\
            "KV Cache DataType") and value == \
                "auto"):
            logger.info(f"{key} = {value!r}")
    logger.info("-" * 85)
    # TODO: Print more configs in debug mode.

    from aphrodite.plugins import load_general_plugins
    load_general_plugins()

    self.model_config = model_config
    self.cache_config = cache_config
    self.lora_config = lora_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.device_config = device_config
    self.speculative_config = speculative_config
    self.load_config = load_config
    self.decoding_config = decoding_config or DecodingConfig()
    self.prompt_adapter_config = prompt_adapter_config
    self.log_stats = log_stats

    # Without a tokenizer there is no detokenizer either, and the
    # per-sequence tokenizer lookup defined below cannot be used.
    if not self.model_config.skip_tokenizer_init:
        self.tokenizer = self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
        tokenizer_group = self.get_tokenizer_group()
    else:
        self.tokenizer = None
        self.detokenizer = None
        tokenizer_group = None

    # Ensure that the function doesn't contain a reference to self,
    # to avoid engine GC issues
    def get_tokenizer_for_seq(sequence: Sequence) -> PreTrainedTokenizer:
        assert tokenizer_group, ("tokenizer_group cannot be None, "
                                 "make sure skip_tokenizer_init is False")
        return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

    self.seq_counter = Counter()
    self.generation_config_fields = _load_generation_config_dict(
        model_config)
    self.input_registry = input_registry
    self.input_processor = input_registry.create_input_processor(
        model_config)

    self.model_executor = executor_class(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        lora_config=lora_config,
        speculative_config=speculative_config,
        load_config=load_config,
        prompt_adapter_config=prompt_adapter_config,
    )

    # KV caches are not initialized in embedding mode.
    if not self.model_config.embedding_mode:
        self._initialize_kv_caches()

    if self.tokenizer:
        # Ping the tokenizer to ensure liveness if it runs in a
        # different process.
        self.tokenizer.ping()

    # Create the scheduler.
    # NOTE: the cache_config here have been updated with the numbers of
    # GPU and CPU blocks, which are profiled in the distributed executor.
    self.scheduler = [
        Scheduler(scheduler_config, cache_config, lora_config,
                  parallel_config.pipeline_parallel_size)
        for _ in range(parallel_config.pipeline_parallel_size)
    ]

    # Metric Logging.
    if self.log_stats:
        if stat_loggers is not None:
            self.stat_loggers = stat_loggers
        else:
            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from aphrodite.engine.metrics import (LoggingStatLogger,
                                                  PrometheusStatLogger)
            self.stat_loggers = {
                "logging":
                LoggingStatLogger(
                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                "prometheus":
                PrometheusStatLogger(
                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                    labels=dict(model_name=model_config.served_model_name),
                    max_model_len=self.model_config.max_model_len),
            }
            self.stat_loggers["prometheus"].info("cache_config",
                                                 self.cache_config)

    # Create sequence output processor, e.g. for beam search or
    # speculative decoding.
    self.output_processor = (
        SequenceGroupOutputProcessor.create_output_processor(
            self.scheduler_config,
            self.detokenizer,
            self.scheduler,
            self.seq_counter,
            get_tokenizer_for_seq,
            stop_checker=StopChecker(
                self.scheduler_config.max_model_len,
                get_tokenizer_for_seq,
            ),
        ))
- def _initialize_kv_caches(self) -> None:
- """Initialize the KV cache in the worker(s).
- The workers will determine the number of blocks in both the GPU cache
- and the swap CPU cache.
- """
- num_gpu_blocks, num_cpu_blocks = (
- self.model_executor.determine_num_available_blocks())
- if self.cache_config.num_gpu_blocks_override is not None:
- num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
- logger.info(f"Overriding {num_gpu_blocks=} with "
- f"{num_gpu_blocks_override=}")
- num_gpu_blocks = num_gpu_blocks_override
- self.cache_config.num_gpu_blocks = num_gpu_blocks
- self.cache_config.num_cpu_blocks = num_cpu_blocks
- self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
@classmethod
def _get_executor_cls(cls,
                      engine_config: EngineConfig) -> Type[ExecutorBase]:
    """Select the executor class for the given engine configuration.

    An executor class passed directly as `distributed_executor_backend`
    wins. Otherwise the choice is made from the device type ("neuron",
    "tpu", "cpu", "openvino", "xpu") and, failing those, from the backend
    name ("ray", "mp"), falling back to the single-process GPU executor.
    The Ray cluster is initialized whenever a Ray-based executor is
    chosen.

    Args:
        engine_config: The engine configuration; its `parallel_config`
            and `device_config` are consulted.

    Returns:
        The executor class to instantiate.

    Raises:
        TypeError: If `distributed_executor_backend` is a class that does
            not derive from `ExecutorBase`.
    """
    parallel_config = engine_config.parallel_config
    backend = parallel_config.distributed_executor_backend

    # Initialize the cluster and specify the executor class.
    if isinstance(backend, type):
        if not issubclass(backend, ExecutorBase):
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorBase. Got {backend}.")
        if backend.uses_ray:  # type: ignore
            initialize_ray_cluster(parallel_config)
        return backend

    device_type = engine_config.device_config.device_type
    wants_ray = backend == "ray"

    if device_type == "neuron":
        from aphrodite.executor.neuron_executor import NeuronExecutor
        return NeuronExecutor

    if device_type == "tpu":
        if wants_ray:
            initialize_ray_cluster(parallel_config)
            from aphrodite.executor.ray_tpu_executor import RayTPUExecutor
            return RayTPUExecutor
        assert backend is None
        from aphrodite.executor.tpu_executor import TPUExecutor
        return TPUExecutor

    if device_type == "cpu":
        from aphrodite.executor.cpu_executor import CPUExecutor
        return CPUExecutor

    if device_type == "openvino":
        from aphrodite.executor.openvino_executor import OpenVINOExecutor
        return OpenVINOExecutor

    if device_type == "xpu":
        if wants_ray:
            initialize_ray_cluster(parallel_config)
            from aphrodite.executor.ray_xpu_executor import RayXPUExecutor
            return RayXPUExecutor
        from aphrodite.executor.xpu_executor import XPUExecutor
        return XPUExecutor

    # No special device type matched: choose among the GPU executors.
    if wants_ray:
        initialize_ray_cluster(parallel_config)
        from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
        return RayGPUExecutor

    if backend == "mp":
        from aphrodite.executor.multiproc_gpu_executor import (
            MultiprocessingGPUExecutor)
        assert not APHRODITE_USE_RAY_SPMD_WORKER, (
            "multiprocessing distributed executor backend does not "
            "support APHRODITE_USE_RAY_SPMD_WORKER=1")
        return MultiprocessingGPUExecutor

    from aphrodite.executor.gpu_executor import GPUExecutor
    return GPUExecutor
@classmethod
def from_engine_args(
    cls,
    engine_args: EngineArgs,
    stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AphroditeEngine":
    """Creates an LLM engine from the engine arguments.

    Args:
        engine_args: The arguments from which the engine config is built.
        stat_loggers: Optional stat loggers forwarded to the constructor.

    Returns:
        The newly constructed engine.
    """
    # Derive the configs, then the executor class that matches them.
    engine_config = engine_args.create_engine_config()
    executor_class = cls._get_executor_cls(engine_config)

    # Stat logging is on unless explicitly disabled in the arguments.
    init_kwargs = engine_config.to_dict()
    return cls(
        **init_kwargs,
        executor_class=executor_class,
        log_stats=not engine_args.disable_log_stats,
        stat_loggers=stat_loggers,
    )
- def __reduce__(self):
- # This is to ensure that the AphroditeEngine is not referenced in
- # the closure used to initialize Ray worker actors
- raise RuntimeError("AphroditeEngine should not be pickled!")
- def __del__(self):
- # Shutdown the model executor when engine is garbage collected.
- # Use getattr since __init__ can fail before the field is set
- if model_executor := getattr(self, "model_executor", None):
- model_executor.shutdown()
- MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
- "skip_tokenizer_init is True")
- def get_tokenizer_group(
- self,
- fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
- if self.tokenizer is None:
- raise ValueError(fail_msg)
- return self.tokenizer
- def get_tokenizer(
- self,
- lora_request: Optional[LoRARequest] = None
- ) -> "PreTrainedTokenizer":
- return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup:
    """Create the tokenizer group from the engine's configs.

    LoRA support is enabled in the tokenizer group exactly when a LoRA
    config is present (truthy).
    """
    lora_enabled = bool(self.lora_config)
    return init_tokenizer_from_configs(
        model_config=self.model_config,
        scheduler_config=self.scheduler_config,
        parallel_config=self.parallel_config,
        enable_lora=lora_enabled,
    )
- def _verify_args(self) -> None:
- self.model_config.verify_with_parallel_config(self.parallel_config)
- self.cache_config.verify_with_parallel_config(self.parallel_config)
- if self.lora_config:
- self.lora_config.verify_with_model_config(self.model_config)
- self.lora_config.verify_with_scheduler_config(
- self.scheduler_config)
- if self.prompt_adapter_config:
- self.prompt_adapter_config.verify_with_model_config(
- self.model_config)
- def _get_bos_token_id(self,
- lora_request: Optional[LoRARequest] = None
- ) -> Optional[int]:
- if self.tokenizer is None:
- logger.warning("Using None for BOS token id because tokenizer "
- "is not initialized")
- return None
- return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
- def _get_eos_token_id(self,
- lora_request: Optional[LoRARequest] = None
- ) -> Optional[int]:
- if self.tokenizer is None:
- logger.warning("Using None for EOS token id because tokenizer "
- "is not initialized")
- return None
- return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
- def _get_decoder_start_token_id(self) -> Optional[int]:
- '''
- Obtain the decoder start token id employed by an encoder/decoder
- model. Returns None for non-encoder/decoder models or if the
- model config is unavailable.
- '''
- if not self.is_encoder_decoder_model():
- logger.warning("Using None for decoder start token id because "
- "this is not an encoder/decoder model.")
- return None
- if (self.model_config is None or self.model_config.hf_config is None):
- logger.warning("Using None for decoder start token id because "
- "model config is not available.")
- return None
- dec_start_token_id = getattr(self.model_config.hf_config,
- 'decoder_start_token_id', None)
- if dec_start_token_id is None:
- logger.warning("Falling back on <BOS> for decoder start token id "
- "because decoder start token id is not available.")
- dec_start_token_id = self._get_bos_token_id()
- return dec_start_token_id
def _add_processed_request(
    self,
    request_id: str,
    processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
    params: Union[SamplingParams, PoolingParams],
    arrival_time: float,
    lora_request: Optional[LoRARequest],
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
    """Wrap processed inputs in a sequence group and schedule it.

    A decoder `Sequence` is always created. An additional encoder
    sequence, sharing the same sequence id, is created when the inputs
    contain 'encoder_prompt_token_ids'. The resulting sequence group is
    handed to the scheduler with the fewest unfinished sequence groups.

    Args:
        request_id: The unique ID of the request.
        processed_inputs: The inputs produced by the input processor.
        params: Sampling params (generation) or pooling params.
        arrival_time: The arrival time of the request.
        lora_request: Optional LoRA request.
        prompt_adapter_request: Optional prompt adapter request.

    Raises:
        ValueError: If `params` is neither SamplingParams nor
            PoolingParams.
    """
    # Create the sequences.
    block_size = self.cache_config.block_size
    seq_id = next(self.seq_counter)
    eos_token_id = self._get_eos_token_id(lora_request)

    # Positional arguments common to the decoder and encoder sequences.
    seq_args = (seq_id, processed_inputs, block_size, eos_token_id,
                lora_request, prompt_adapter_request)
    seq = Sequence(*seq_args)
    if 'encoder_prompt_token_ids' in processed_inputs:
        encoder_seq = Sequence(*seq_args, from_decoder_prompt=False)
    else:
        encoder_seq = None

    # Create a SequenceGroup based on SamplingParams or PoolingParams
    group_kwargs: Dict[str, Any] = dict(
        arrival_time=arrival_time,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
        encoder_seq=encoder_seq,
    )
    if isinstance(params, SamplingParams):
        seq_group = self._create_sequence_group_with_sampling(
            request_id, seq, params, **group_kwargs)
    elif isinstance(params, PoolingParams):
        seq_group = self._create_sequence_group_with_pooling(
            request_id, seq, params, **group_kwargs)
    else:
        raise ValueError(
            "Either SamplingParams or PoolingParams must be provided.")

    # Add the sequence group to the scheduler with least unfinished seqs.
    # `min` returns the first scheduler among those tied for the minimum.
    min_cost_scheduler = min(
        self.scheduler,
        key=lambda scheduler: scheduler.get_num_unfinished_seq_groups())
    min_cost_scheduler.add_seq_group(seq_group)
- def stop_remote_worker_execution_loop(self) -> None:
- self.model_executor.stop_remote_worker_execution_loop()
- _LLMInputComponentsType = Tuple[str, List[int]]
- def _prepare_decoder_input_ids_for_generation(
- self,
- decoder_input_ids: Optional[List[int]],
- ) -> List[int]:
- """
- Prepares `decoder_input_ids` for generation with encoder-decoder models.
- Based on
- https://github.com/huggingface/transformers/blob/
- 4037a2b5b1278736e566aec12e169100275545ea/
- src/transformers/generation/utils.py
- specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
- Arguments:
- * decoder_input_ids: input token ids to preprocess
- Returns:
- * Processed token list
- """
- decoder_start_token_id = self._get_decoder_start_token_id()
- assert decoder_start_token_id is not None
- if decoder_input_ids is None:
- # no decoder prompt input ->
- # use decoder_start_token_id as decoder_input_ids
- decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
- if (len(decoder_input_ids) == 0
- or decoder_input_ids[0] != decoder_start_token_id):
- decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
- return decoder_input_ids
- def _tokenize_prompt(
- self,
- prompt: str,
- request_id: str,
- lora_request: Optional[LoRARequest],
- ) -> List[int]:
- '''
- Wrapper around application of the model's tokenizer.
- Arguments:
- * prompt
- * request_id
- * lora_request
- Returns:
- * prompt token ids
- '''
- tokenizer = self.get_tokenizer_group("prompts must be None if "
- "skip_tokenizer_init is True")
- return tokenizer.encode(request_id=request_id,
- prompt=prompt,
- lora_request=lora_request)
- def _extract_prompt_components(
- self,
- inputs: SingletonPromptInputs,
- request_id: str,
- lora_request: Optional[LoRARequest] = None,
- ) -> PromptComponents:
- '''
- Extract the components of any single encoder or decoder input prompt.
- Arguments:
- * request_id
- * inputs: single encoder or decoder input prompt
- * lora_request: this is only valid for decoder prompts
- Returns:
- * prompt
- * prompt_token_ids
- * multi_modal_data
- '''
- if isinstance(inputs, str):
- prompt = inputs
- prompt_token_ids = self._tokenize_prompt(
- prompt,
- request_id=request_id,
- lora_request=lora_request,
- )
- multi_modal_data = None
- elif isinstance(inputs, dict):
- if "prompt_token_ids" in inputs:
- prompt = None
- prompt_token_ids = inputs["prompt_token_ids"]
- else:
- # NOTE: This extra assignment is required to pass mypy
- prompt = parsed_prompt = inputs["prompt"]
- prompt_token_ids = self._tokenize_prompt(
- parsed_prompt,
- request_id=request_id,
- lora_request=lora_request,
- )
- multi_modal_data = inputs.get("multi_modal_data")
- else:
- assert_never(inputs)
- return prompt, prompt_token_ids, multi_modal_data
- def _apply_prompt_adapter(
- self,
- prompt_token_ids: List[int],
- prompt_adapter_request: Optional[PromptAdapterRequest],
- ) -> List[int]:
- if prompt_adapter_request:
- prompt_token_ids = (
- [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
- + prompt_token_ids)
- return prompt_token_ids
- def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
- '''
- Specifically for encoder/decoder models:
- generate a default decoder prompt for when
- the user specifies only the encoder prompt.
- Encoder/decoder models utilize the decoder
- prompt in different ways; as new models are
- added, it is intended that this function
- will be extended to produce differing
- default decoder prompts, depending on the
- model variety.
- Absent a special case, the default behavior
- of this method is to mirror the behavior of
- the HuggingFace (HF) GenerationMixin for a None
- decoder prompt, which is to employ a logit processor
- setting to force the first decoded token to be <BOS>.
- Here, this behavior is approximated by having the
- "default" decoder prompt be <BOS>.
- However, it is possible that in the future
- other models may have different or more
- complex logic for the default decoder prompt.
- This motivates having a special helper method
- for default decoder prompts.
- Returns:
- * prompt_token_ids
- '''
- bos_token_id = self._get_bos_token_id()
- assert bos_token_id is not None
- return [bos_token_id]
- def _build_enc_dec_llm_inputs(
- self,
- encoder_comps: PromptComponents,
- decoder_comps: DecoderPromptComponents,
- ) -> EncoderDecoderLLMInputs:
- encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
- decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
- if encoder_mm_data is not None or decoder_mm_data is not None:
- raise ValueError("Multi-modal encoder-decoder models are "
- "not supported yet")
- decoder_prompt_ids = (
- self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
- return EncoderDecoderLLMInputs(
- prompt_token_ids=decoder_prompt_ids,
- prompt=decoder_prompt,
- encoder_prompt_token_ids=encoder_prompt_ids,
- encoder_prompt=encoder_prompt,
- )
def _process_encoder_decoder_prompt(
    self,
    inputs: PromptInputs,
    request_id: str,
) -> EncoderDecoderLLMInputs:
    """Process a prompt into :class:`EncoderDecoderLLMInputs`.

    For encoder/decoder models only. Two kinds of prompt are handled:

    * Singleton prompt: treated as the encoder prompt; the decoder prompt
      token ids are left to be inferred by default.
    * Explicit encoder/decoder prompt: the encoder and decoder prompts
      are read from its "encoder_prompt" and "decoder_prompt" entries.
      The decoder prompt may be None, which is handled like the
      singleton case.

    Each sub-prompt of an explicit prompt can be of any singleton type,
    so token ids are obtained through `_extract_prompt_components`.

    Args:
        inputs: An input prompt.
        request_id: The ID of the request.

    Returns:
        The :class:`EncoderDecoderLLMInputs` instance.
    """
    encoder_comps: PromptComponents
    decoder_comps: DecoderPromptComponents
    no_decoder_prompt: DecoderPromptComponents = (None, None, None)

    if not is_explicit_encoder_decoder_prompt(inputs):
        # Singleton prompt: it is the encoder prompt.
        encoder_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
        )
        return self._build_enc_dec_llm_inputs(encoder_comps,
                                              no_decoder_prompt)

    encoder_comps = self._extract_prompt_components(
        inputs["encoder_prompt"],
        request_id=request_id,
    )
    decoder_input = inputs["decoder_prompt"]
    if decoder_input is None:
        decoder_comps = no_decoder_prompt
    else:
        decoder_comps = self._extract_prompt_components(
            decoder_input,
            request_id=request_id,
        )
    return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
def _build_decoder_only_llm_inputs(
    self,
    prompt_comps: PromptComponents,
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
    """Assemble :class:`LLMInputs` from extracted prompt components.

    Args:
        prompt_comps: `(prompt, token ids, multi-modal data)` of the
            prompt.
        prompt_adapter_request: Optional prompt adapter request; when
            given, it is applied to the token ids via
            `_apply_prompt_adapter`.

    Returns:
        The :class:`LLMInputs` instance.
    """
    prompt, prompt_token_ids, multi_modal_data = prompt_comps
    adapted_token_ids = self._apply_prompt_adapter(
        prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
    llm_inputs = LLMInputs(prompt_token_ids=adapted_token_ids,
                           prompt=prompt,
                           multi_modal_data=multi_modal_data)
    return llm_inputs
def _process_decoder_only_prompt(
    self,
    inputs: SingletonPromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
    """Process a prompt into :class:`LLMInputs` for decoder-only models.

    Args:
        inputs: The input prompt.
        request_id: The ID of the request.
        lora_request: Optional LoRA request, used when tokenizing.
        prompt_adapter_request: Optional prompt adapter request.

    Returns:
        The :class:`LLMInputs` instance.
    """
    return self._build_decoder_only_llm_inputs(
        self._extract_prompt_components(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        ),
        prompt_adapter_request=prompt_adapter_request,
    )
def process_model_inputs(
    self,
    inputs: PromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
    """Convert a raw prompt into model inputs and run the input processor.

    Args:
        inputs: The input prompt.
        request_id: The ID of the request the prompt belongs to.
        lora_request: Forwarded on the decoder-only path only; the
            encoder-decoder path does not receive it.
        prompt_adapter_request: Forwarded on the decoder-only path only.

    Returns:
        The result of calling ``self.input_processor`` on the processed
        prompt.

    Raises:
        ValueError: If an explicit encoder/decoder prompt is passed to a
            decoder-only model.
    """
    if self.is_encoder_decoder_model():
        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder
        enc_dec_inputs = self._process_encoder_decoder_prompt(
            inputs,
            request_id=request_id,
        )
        return self.input_processor(enc_dec_inputs)

    if is_explicit_encoder_decoder_prompt(inputs):
        raise ValueError("Cannot pass encoder-decoder prompt "
                         "to decoder-only models")

    # Decoder-only operation
    decoder_only_inputs = self._process_decoder_only_prompt(
        inputs,
        request_id=request_id,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
    return self.input_processor(decoder_only_inputs)
def add_request(
    self,
    request_id: str,
    inputs: PromptInputs,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
    """Add a request to the engine's request pool.

    The request is added to the request pool and will be processed by the
    scheduler as `engine.step()` is called. The exact scheduling policy is
    determined by the scheduler.

    Args:
        request_id: The unique ID of the request.
        inputs: The input prompt. It is converted into model inputs by
            :meth:`process_model_inputs`.
        params: Parameters for sampling or pooling. SamplingParams
            for text generation. PoolingParams for pooling.
        arrival_time: The arrival time of the request. If None, the
            current wall-clock time (``time.time()``) is used.
        lora_request: Optional LoRA request. Only accepted when
            ``self.lora_config`` is set.
        prompt_adapter_request: Optional prompt adapter request, passed
            on to input processing and to the processed-request step.

    Raises:
        ValueError: If ``lora_request`` is given but LoRA is not enabled
            (``self.lora_config`` is falsy).

    Details:
        - Reject a ``lora_request`` when LoRA is not enabled.
        - Set arrival_time to the current time if it is None.
        - Convert ``inputs`` with :meth:`process_model_inputs`.
        - Hand the processed inputs to ``_add_processed_request``
          (NOTE(review): that method is outside this view; presumably it
          builds the sequence group and adds it to a scheduler).

    Example:
        >>> # initialize engine
        >>> engine = AphroditeEngine.from_engine_args(engine_args)
        >>> # set request arguments
        >>> example_prompt = "Who is the president of the United States?"
        >>> sampling_params = SamplingParams(temperature=0.0)
        >>> request_id = 0
        >>>
        >>> # add the request to the engine
        >>> engine.add_request(
        >>>    str(request_id),
        >>>    example_prompt,
        >>>    SamplingParams(temperature=0.0))
        >>> # continue the request processing
        >>> ...
    """
    if lora_request is not None and not self.lora_config:
        raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                         "not enabled!")

    if arrival_time is None:
        # Wall-clock time, so it is comparable with the time.time() based
        # timestamps used when statistics are computed.
        arrival_time = time.time()

    processed_inputs = self.process_model_inputs(
        inputs,
        request_id=request_id,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )

    self._add_processed_request(
        request_id=request_id,
        processed_inputs=processed_inputs,
        params=params,
        arrival_time=arrival_time,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
def _create_sequence_group_with_sampling(
    self,
    request_id: str,
    seq: Sequence,
    sampling_params: SamplingParams,
    arrival_time: float,
    lora_request: Optional[LoRARequest],
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup:
    """Build a single-sequence SequenceGroup that carries SamplingParams.

    Raises:
        ValueError: If ``logprobs`` or ``prompt_logprobs`` in
            ``sampling_params`` exceeds the model config's
            ``max_logprobs``.
    """
    limit = self.get_model_config().max_logprobs
    requested_counts = (sampling_params.logprobs,
                        sampling_params.prompt_logprobs)
    # A falsy count (None or 0) means nothing was requested for that kind.
    if any(count and count > limit for count in requested_counts):
        raise ValueError(f"Cannot request more than {limit} logprobs.")

    # Defensive copy of SamplingParams, which are used by the sampler,
    # this doesn't deep-copy LogitsProcessor objects
    params_copy = sampling_params.clone()
    params_copy.update_from_generation_config(
        self.generation_config_fields, seq.eos_token_id)

    return SequenceGroup(
        request_id=request_id,
        seqs=[seq],
        arrival_time=arrival_time,
        sampling_params=params_copy,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
        encoder_seq=encoder_seq,
    )
def _create_sequence_group_with_pooling(
    self,
    request_id: str,
    seq: Sequence,
    pooling_params: PoolingParams,
    arrival_time: float,
    lora_request: Optional[LoRARequest],
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup:
    """Build a single-sequence SequenceGroup that carries PoolingParams."""
    # Defensive copy of PoolingParams, which are used by the pooler
    params_copy = pooling_params.clone()
    return SequenceGroup(
        request_id=request_id,
        seqs=[seq],
        arrival_time=arrival_time,
        lora_request=lora_request,
        pooling_params=params_copy,
        prompt_adapter_request=prompt_adapter_request,
        encoder_seq=encoder_seq,
    )
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
    """Abort the request(s) with the given ID(s).

    Args:
        request_id: The ID(s) of the request to abort.

    Details:
        - Every scheduler in ``self.scheduler`` is asked to abort the
          ID(s); see
          :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
          from class :class:`~aphrodite.processing.scheduler.Scheduler`.

    Example:
        >>> # initialize engine and add a request with request_id
        >>> request_id = str(0)
        >>> # abort the request
        >>> engine.abort_request(request_id)
    """
    for sched in self.scheduler:
        sched.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
    """Return the engine's model configuration (``self.model_config``)."""
    model_config = self.model_config
    return model_config
def get_parallel_config(self) -> ParallelConfig:
    """Return the engine's parallel configuration (``self.parallel_config``)."""
    parallel_config = self.parallel_config
    return parallel_config
def get_decoding_config(self) -> DecodingConfig:
    """Return the engine's decoding configuration (``self.decoding_config``)."""
    decoding_config = self.decoding_config
    return decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
    """Return the engine's scheduler configuration (``self.scheduler_config``)."""
    scheduler_config = self.scheduler_config
    return scheduler_config
def get_lora_config(self) -> Optional[LoRAConfig]:
    """Gets the LoRA configuration.

    Returns:
        ``self.lora_config`` unchanged. The value may be None (or otherwise
        falsy) when LoRA is not enabled; ``add_request`` relies on exactly
        that to reject LoRA requests, so callers must handle it.
    """
    return self.lora_config
def get_num_unfinished_requests(self) -> int:
    """Return the unfinished sequence-group count summed over all schedulers."""
    total = 0
    for sched in self.scheduler:
        total += sched.get_num_unfinished_seq_groups()
    return total
def has_unfinished_requests(self) -> bool:
    """Return True if any scheduler reports unfinished sequences."""
    for sched in self.scheduler:
        # Stop at the first scheduler that still has work.
        if sched.has_unfinished_seqs():
            return True
    return False
def has_unfinished_requests_for_virtual_engine(
        self, virtual_engine: int) -> bool:
    """
    Report whether the scheduler at index ``virtual_engine`` of
    ``self.scheduler`` has unfinished sequences.
    """
    engine_scheduler = self.scheduler[virtual_engine]
    return engine_scheduler.has_unfinished_seqs()
def _process_sequence_group_outputs(
    self,
    seq_group: SequenceGroup,
    outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
    """Store the embeddings on ``seq_group`` and mark its sequences stopped.

    Only the first element of ``outputs`` is read; every sequence returned
    by ``seq_group.get_seqs()`` is set to
    ``SequenceStatus.FINISHED_STOPPED``.
    """
    first_output = outputs[0]
    seq_group.embeddings = first_output.embeddings

    for sequence in seq_group.get_seqs():
        sequence.status = SequenceStatus.FINISHED_STOPPED
def _process_model_outputs(
    self,
    output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
    scheduled_seq_groups: List[ScheduledSequenceGroup],
    ignored_seq_groups: List[SequenceGroup],
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
    """Apply the model output to the sequences in the scheduled seq groups.

    Returns RequestOutputs that can be returned to the client: one per
    scheduled sequence group, followed by one per ignored sequence group.
    """
    now = time.time()

    # Organize outputs by [sequence group][step] instead of
    # [step][sequence group].
    grouped_outputs = create_output_by_sequence_group(
        output, num_seq_groups=len(scheduled_seq_groups))

    # Update the scheduled sequence groups with the model outputs.
    for scheduled, group_outputs, metadata in zip(scheduled_seq_groups,
                                                  grouped_outputs,
                                                  seq_group_metadata_list):
        group = scheduled.seq_group
        group.update_num_computed_tokens(scheduled.token_chunk_size)

        if self.model_config.embedding_mode:
            # Embedding models bypass the output processor entirely.
            self._process_sequence_group_outputs(group, group_outputs)
        else:
            self.output_processor.process_prompt_logprob(
                group, group_outputs)
            if metadata.do_sample:
                self.output_processor.process_outputs(group, group_outputs)

    # Free the finished sequence groups.
    for sched in self.scheduler:
        sched.free_finished_seq_groups()

    # Create the outputs.
    request_outputs: List[Union[RequestOutput,
                                EmbeddingRequestOutput]] = []
    for scheduled in scheduled_seq_groups:
        group = scheduled.seq_group
        group.maybe_set_first_token_time(now)
        request_outputs.append(RequestOutputFactory.create(group))
    request_outputs.extend(
        RequestOutputFactory.create(group) for group in ignored_seq_groups)
    return request_outputs
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
    """Performs one decoding iteration and returns newly generated results.

    .. figure:: https://i.imgur.com/sv2HssD.png
        :alt: Overview of the step function
        :align: center

        Overview of the step function.

    Details:
        - Step 1: Schedules the sequences to be executed in the next
          iteration and the token blocks to be swapped in/out/copy.

            - Depending on the scheduling policy,
              sequences may be `preempted/reordered`.
            - A Sequence Group (SG) refer to a group of sequences
              that are generated from the same prompt.

        - Step 2: Calls the distributed executor to execute the model
          (skipped when the scheduler output is empty).
        - Step 3: Processes the model output. This mainly includes:

            - Decodes the relevant outputs.
            - Updates the scheduled sequence groups with model outputs
              based on its `sampling parameters` (`use_beam_search` or not).
            - Frees the finished sequence groups.

        - Finally, it creates and returns the newly generated results.

    Raises:
        NotImplementedError: If pipeline parallelism or multi-step
            scheduling is configured; both are only supported through
            AsyncAphrodite.

    Example:
        >>> # Please see the example/ folder for more detailed examples.
        >>>
        >>> # initialize engine and request arguments
        >>> engine = AphroditeEngine.from_engine_args(engine_args)
        >>> example_inputs = [(0, "What is LLM?",
        >>>    SamplingParams(temperature=0.0))]
        >>>
        >>> # Start the engine with an event loop
        >>> while True:
        >>>     if example_inputs:
        >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
        >>>         engine.add_request(str(req_id), prompt, sampling_params)
        >>>
        >>>     # continue the request processing
        >>>     request_outputs = engine.step()
        >>>     for request_output in request_outputs:
        >>>         if request_output.finished:
        >>>             # return or show the request output
        >>>
        >>>     if not (engine.has_unfinished_requests() or example_inputs):
        >>>         break
    """
    if self.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError(
            "Pipeline parallelism is only supported through AsyncAphrodite "
            "as performance will be severely degraded otherwise.")
    if self.scheduler_config.num_scheduler_steps > 1:
        raise NotImplementedError(
            "Multiple scheduler steps (multi-step) are only supported "
            "through AsyncAphrodite.")

    # The synchronous engine only ever drives the first scheduler.
    primary_scheduler = self.scheduler[0]
    seq_group_metadata_list, scheduler_outputs = primary_scheduler.schedule()

    if scheduler_outputs.is_empty():
        output = []
    else:
        finished_requests_ids = (
            primary_scheduler.get_and_reset_finished_requests_ids())
        execute_model_req = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
            num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
            running_queue_size=scheduler_outputs.running_queue_size,
            finished_requests_ids=finished_requests_ids,
        )
        output = self.model_executor.execute_model(
            execute_model_req=execute_model_req)

    request_outputs = self._process_model_outputs(
        output,
        scheduler_outputs.scheduled_seq_groups,
        scheduler_outputs.ignored_seq_groups,
        seq_group_metadata_list,
    )

    # Log stats.
    self.do_log_stats(scheduler_outputs, output)

    if not self.has_unfinished_requests():
        # Stop the execute model loop in parallel workers until there are
        # more requests to process. This avoids waiting indefinitely in
        # torch.distributed ops which may otherwise timeout, and unblocks
        # the RPC thread in the workers so that they can process any other
        # queued control plane messages, such as add/remove lora adapters.
        self.model_executor.stop_remote_worker_execution_loop()

    return request_outputs
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
    """Register ``logger`` under ``logger_name`` in ``self.stat_loggers``.

    Raises:
        KeyError: If a logger with that name is already registered.
    """
    registry = self.stat_loggers
    if logger_name in registry:
        raise KeyError(f"Logger with name {logger_name} already exists.")
    registry[logger_name] = logger
def remove_logger(self, logger_name: str) -> None:
    """Unregister the logger stored under ``logger_name``.

    Raises:
        KeyError: If no logger with that name is registered.
    """
    registry = self.stat_loggers
    if logger_name in registry:
        del registry[logger_name]
        return
    raise KeyError(f"Logger with name {logger_name} does not exist.")
def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None) -> None:
    """Compute the current stats and pass them to every stat logger.

    Does nothing when ``self.log_stats`` is falsy. Both arguments are
    optional, so this can also be called to force a log when no requests
    are active.
    """
    if not self.log_stats:
        return

    stats = self._get_stats(scheduler_outputs, model_output)
    for stat_logger in self.stat_loggers.values():
        stat_logger.log(stats)
def _get_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs],
        model_output: Optional[List[SamplerOutput]] = None) -> Stats:
    """Collect the Stats snapshot that is handed to the stat loggers.

    Args:
        scheduler_outputs: Optional, used to populate metrics related to
            the scheduled batch,
        model_output: Optional, used to emit speculative decoding metrics
            which are created by the workers.

    Returns:
        A ``Stats`` instance with system-level values (queue sizes, cache
        usage, prefix cache hit rates), per-iteration values (token counts,
        latencies, preemptions) and per-request values for requests that
        finished in this iteration.
    """
    now = time.time()

    # System State
    #   Scheduler State
    num_running_sys = sum(
        len(scheduler.running) for scheduler in self.scheduler)
    num_swapped_sys = sum(
        len(scheduler.swapped) for scheduler in self.scheduler)
    num_waiting_sys = sum(
        len(scheduler.waiting) for scheduler in self.scheduler)

    #   KV Cache Usage in %
    num_total_gpu = self.cache_config.num_gpu_blocks
    gpu_cache_usage_sys = 0.
    # Truthiness guards against both None and 0 so the division below can
    # never raise ZeroDivisionError (same guard as the CPU branch).
    if num_total_gpu:
        num_free_gpu = sum(
            scheduler.block_manager.get_num_free_gpu_blocks()
            for scheduler in self.scheduler)
        # Attention-free models keep the default usage of 0.
        if not self.model_config.is_attention_free():
            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

    num_total_cpu = self.cache_config.num_cpu_blocks
    cpu_cache_usage_sys = 0.
    if num_total_cpu is not None and num_total_cpu > 0:
        num_free_cpu = sum(
            scheduler.block_manager.get_num_free_cpu_blocks()
            for scheduler in self.scheduler)
        if not self.model_config.is_attention_free():
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

    #   Prefix Cache Hit Rate. Note that we always use
    #   the cache hit rate of the first virtual engine.
    cpu_prefix_cache_hit_rate = self.scheduler[
        0].get_prefix_cache_hit_rate(Device.CPU)
    gpu_prefix_cache_hit_rate = self.scheduler[
        0].get_prefix_cache_hit_rate(Device.GPU)

    # Iteration stats
    num_prompt_tokens_iter = 0
    num_generation_tokens_iter = 0
    time_to_first_tokens_iter: List[float] = []
    time_per_output_tokens_iter: List[float] = []
    num_preemption_iter = (0 if scheduler_outputs is None else
                           scheduler_outputs.preempted)

    # Request stats
    #   Latency
    time_e2e_requests: List[float] = []
    #   Metadata
    num_prompt_tokens_requests: List[int] = []
    num_generation_tokens_requests: List[int] = []
    best_of_requests: List[int] = []
    n_requests: List[int] = []
    finished_reason_requests: List[str] = []

    # NOTE: This loop assumes prefill seq_groups are before
    # decode seq_groups in scheduled_seq_groups.
    if scheduler_outputs is not None:
        # Kept as an int so the derived token count stays an int.
        num_generation_tokens_from_prefill_groups = 0
        # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
        # the len of scheduler_outputs.scheduled_seq_groups is !=
        # scheduler_outputs.num_prefill_groups, this means that
        # chunked prefills have been detected.

        for idx, scheduled_seq_group in enumerate(
                scheduler_outputs.scheduled_seq_groups):
            group_was_prefill = idx < scheduler_outputs.num_prefill_groups
            seq_group = scheduled_seq_group.seq_group

            # NOTE: a seq_group that completed all of its prefill tokens
            # in the last iteration will have seq_group.is_prefill() = False
            # with group_was_prefill = True
            if group_was_prefill:
                # Number of prompt tokens.
                num_prompt_tokens_iter += (
                    scheduled_seq_group.token_chunk_size)

                # If the seq_group just finished the prefill state
                # get TTFT.
                if not seq_group.is_prefill():
                    latency = seq_group.get_last_latency(now)
                    time_to_first_tokens_iter.append(latency)

                    # One generation token per finished prefill.
                    num_generation_tokens_from_prefill_groups += (
                        seq_group.num_seqs())
            else:
                # TPOTs.
                latency = seq_group.get_last_latency(now)
                time_per_output_tokens_iter.append(latency)

            # Because of chunked prefill, we can have a single sequence
            # group that does multiple prompt_runs. To prevent logging
            # the same metadata more than once per request, we standardize
            # on logging request level information for finished requests,
            # which can only happen once.
            if seq_group.is_finished():
                # Latency timings
                time_e2e_requests.append(now -
                                         seq_group.metrics.arrival_time)

                # Metadata
                num_prompt_tokens_requests.append(
                    len(seq_group.prompt_token_ids))
                num_generation_tokens_requests.extend([
                    seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()
                ])
                if seq_group.sampling_params is not None:
                    best_of_requests.append(
                        seq_group.sampling_params.best_of)
                    n_requests.append(seq_group.sampling_params.n)
                finished_reason_requests.extend([
                    SequenceStatus.get_finished_reason(seq.status)
                    for seq in seq_group.get_finished_seqs()
                ])

        # Number of generation tokens.
        #   num_batched_tokens equals the number of prompt_tokens plus the
        #   number of decode_tokens in a single iteration. So,
        #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
        #   + num_generation_tokens_from_prefill_groups (since we generate
        #   one token on prefills on iters where the prefill finishes).
        num_generation_tokens_iter = (
            scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
            num_generation_tokens_from_prefill_groups)

    # Spec decode, if enabled, emits specialized metrics from the worker in
    # sampler output. Only the first element of model_output is consulted.
    if model_output and (model_output[0].spec_decode_worker_metrics
                         is not None):
        spec_decode_metrics = model_output[0].spec_decode_worker_metrics
    else:
        spec_decode_metrics = None

    return Stats(
        now=now,
        # System stats
        #   Scheduler State
        num_running_sys=num_running_sys,
        num_swapped_sys=num_swapped_sys,
        num_waiting_sys=num_waiting_sys,
        #   KV Cache Usage in %
        gpu_cache_usage_sys=gpu_cache_usage_sys,
        cpu_cache_usage_sys=cpu_cache_usage_sys,
        #   Prefix Cache Hit Rate
        cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
        gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

        # Iteration stats
        num_prompt_tokens_iter=num_prompt_tokens_iter,
        num_generation_tokens_iter=num_generation_tokens_iter,
        time_to_first_tokens_iter=time_to_first_tokens_iter,
        time_per_output_tokens_iter=time_per_output_tokens_iter,
        spec_decode_metrics=spec_decode_metrics,
        num_preemption_iter=num_preemption_iter,

        # Request stats
        #   Latency
        time_e2e_requests=time_e2e_requests,
        #   Metadata
        num_prompt_tokens_requests=num_prompt_tokens_requests,
        num_generation_tokens_requests=num_generation_tokens_requests,
        best_of_requests=best_of_requests,
        n_requests=n_requests,
        finished_reason_requests=finished_reason_requests,
    )
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Forward ``lora_request`` to the model executor's ``add_lora``."""
    executor = self.model_executor
    return executor.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
    """Forward ``lora_id`` to the model executor's ``remove_lora``."""
    executor = self.model_executor
    return executor.remove_lora(lora_id)
def list_loras(self) -> List[int]:
    """Return whatever the model executor's ``list_loras`` reports."""
    executor = self.model_executor
    return executor.list_loras()
def pin_lora(self, lora_id: int) -> bool:
    """Forward ``lora_id`` to the model executor's ``pin_lora``."""
    executor = self.model_executor
    return executor.pin_lora(lora_id)
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Forward the request to the model executor's ``add_prompt_adapter``."""
    executor = self.model_executor
    return executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Forward the ID to the model executor's ``remove_prompt_adapter``."""
    executor = self.model_executor
    return executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
    """Return whatever the model executor's ``list_prompt_adapters`` reports."""
    executor = self.model_executor
    return executor.list_prompt_adapters()
def check_health(self) -> None:
    """Run the health checks of the tokenizer (if set) and the model executor.

    The tokenizer check is skipped when ``self.tokenizer`` is falsy; the
    model executor check always runs afterwards.
    """
    tokenizer = self.tokenizer
    if tokenizer:
        tokenizer.check_health()
    self.model_executor.check_health()
def is_encoder_decoder_model(self):
    """Return ``self.model_config.is_encoder_decoder_model`` unchanged."""
    model_config = self.model_config
    return model_config.is_encoder_decoder_model
def is_embedding_model(self):
    """Return ``self.model_config.is_embedding_model`` unchanged."""
    model_config = self.model_config
    return model_config.is_embedding_model
- setup_logger()  # NOTE(review): bare call after the last method; presumably module-level logging setup run at import time — setup_logger is defined outside this view, confirm.
|