import time
from typing import Iterable, List, Optional, Tuple, Type, Union

from loguru import logger
from transformers import PreTrainedTokenizer

import aphrodite
from aphrodite.lora.request import LoRARequest
from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
                                     ParallelConfig, SchedulerConfig,
                                     LoRAConfig, VisionLanguageConfig,
                                     SpeculativeConfig)
from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.executor.executor_base import ExecutorBase
from aphrodite.engine.metrics import StatLogger, Stats
from aphrodite.engine.ray_tools import initialize_ray_cluster
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
                                       SequenceGroupOutput, SequenceOutput,
                                       SequenceStatus, MultiModalData)
from aphrodite.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                          get_tokenizer_group)
from aphrodite.transformers_utils.detokenizer import Detokenizer
from aphrodite.common.utils import Counter
from aphrodite.common.logger import setup_logger

_LOCAL_LOGGING_INTERVAL_SEC = 5

class AphroditeEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the Aphrodite engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU
    memory space allocated for intermediate states (aka KV cache). This class
    utilizes iteration-level scheduling and efficient memory management to
    maximize the serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncAphrodite` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For
    the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        lora_config (Optional): The configuration related to serving
            multi-LoRA.
        vision_language_config (Optional): The configuration related to
            vision language models.
        speculative_config (Optional): The configuration related to
            speculative decoding.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
    ) -> None:
        logger.info(
            f"Initializing the Aphrodite Engine (v{aphrodite.__version__}) "
            "with the following config:\n"
            f"Model = {model_config.model!r}\n"
            f"Speculative Config = {speculative_config!r}\n"
            f"DataType = {model_config.dtype}\n"
            f"Model Load Format = {model_config.load_format}\n"
            f"Number of GPUs = {parallel_config.tensor_parallel_size}\n"
            f"Disable Custom All-Reduce = "
            f"{parallel_config.disable_custom_all_reduce}\n"
            f"Quantization Format = {model_config.quantization}\n"
            f"Context Length = {model_config.max_model_len}\n"
            f"Enforce Eager Mode = {model_config.enforce_eager}\n"
            f"KV Cache Data Type = {cache_config.cache_dtype}\n"
            f"KV Cache Params Path = {model_config.quantization_param_path}\n"
            f"Device = {device_config.device}")
        # TODO: Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.log_stats = log_stats
        self._verify_args()

        self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
        self.seq_counter = Counter()

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
        )

        self._initialize_kv_caches()

        # Ping the tokenizer to ensure it is loaded if
        # it runs on a separate process.
        self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric Logging.
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels=dict(model_name=model_config.model),
            )
            self.stat_logger.info("cache_config", self.cache_config)

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.model_executor.determine_num_available_blocks())

        if self.cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
            logger.info(f"Overriding {num_gpu_blocks=} with "
                        f"{num_gpu_blocks_override=}")
            num_gpu_blocks = num_gpu_blocks_override

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

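    # Illustrative note (added for clarity, not part of the original logic):
    # each KV-cache block holds `cache_config.block_size` token slots, so the
    # GPU cache capacity negotiated above is roughly
    #     num_gpu_blocks * block_size
    # token positions shared by all running sequences. With the assumed
    # example values num_gpu_blocks=2048 and block_size=16, about 32,768
    # tokens can be cached before preemption or swapping kicks in.
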
    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        # Initialize the cluster and specify the executor class.
        if engine_config.device_config.device_type == "neuron":
            from aphrodite.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "cpu":
            from aphrodite.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.parallel_config.worker_use_ray:
            initialize_ray_cluster(engine_config.parallel_config)
            from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        else:
            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from aphrodite.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        # Create the LLM engine.
        engine = cls(**engine_config.to_dict(),
                     executor_class=executor_class,
                     log_stats=not engine_args.disable_log_stats)
        return engine

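    # Usage sketch (illustrative only; the model name below is an arbitrary
    # example, not something this module requires):
    #
    #     engine_args = EngineArgs(model="facebook/opt-125m")
    #     engine = AphroditeEngine.from_engine_args(engine_args)
    #     engine.add_request("0", "Hello, my name is",
    #                        SamplingParams(temperature=0.8))
    #     while engine.has_unfinished_requests():
    #         for request_output in engine.step():
    #             if request_output.finished:
    #                 print(request_output.outputs[0].text)
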
    def __reduce__(self):
        # This is to ensure that the AphroditeEngine is not referenced in
        # the closure used to initialize Ray worker actors.
        raise RuntimeError("AphroditeEngine should not be pickled!")

    def get_tokenizer(self) -> "PreTrainedTokenizer":
        return self.tokenizer.get_lora_tokenizer(None)

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        init_kwargs = dict(
            tokenizer_id=self.model_config.tokenizer,
            enable_lora=bool(self.lora_config),
            max_num_seqs=self.scheduler_config.max_num_seqs,
            max_input_length=None,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision,
        )
        init_kwargs.update(tokenizer_init_kwargs)
        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config, **init_kwargs)

        if len(self.get_tokenizer()) != self.model_config.get_vocab_size():
            logger.warning(
                f"The tokenizer's vocabulary size {len(self.get_tokenizer())}"
                f" does not match the model's vocabulary size "
                f"{self.model_config.get_vocab_size()}. This might "
                f"cause an error in decoding. Please change config.json "
                "to match the tokenizer's vocabulary size.")

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)

    def encode_request(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(
                request_id=request_id,
                prompt=prompt,
                lora_request=lora_request)
        return prompt_token_ids

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current monotonic time.
            multi_modal_data: The multimodal data for the request.

        Details:
            - Set arrival_time to the current time if it is None.
            - Set prompt_token_ids to the encoded prompt if it is None.
            - Create `best_of` number of :class:`~aphrodite.Sequence` objects.
            - Create a :class:`~aphrodite.SequenceGroup` object
              from the list of :class:`~aphrodite.Sequence`.
            - Add the :class:`~aphrodite.SequenceGroup` object to the
              scheduler.

        Example:
            >>> # initialize engine
            >>> engine = AphroditeEngine.from_engine_args(engine_args)
            >>> # set request arguments
            >>> example_prompt = "Who is the president of the United States?"
            >>> sampling_params = SamplingParams(temperature=0.0)
            >>> request_id = 0
            >>>
            >>> # add the request to the engine
            >>> engine.add_request(
            >>>    str(request_id),
            >>>    example_prompt,
            >>>    SamplingParams(temperature=0.0))
            >>> # continue the request processing
            >>> ...
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        max_log_probs = self.get_model_config().max_log_probs
        if (sampling_params.logprobs
                and sampling_params.logprobs > max_log_probs) or (
                    sampling_params.prompt_logprobs
                    and sampling_params.prompt_logprobs > max_log_probs):
            raise ValueError(f"Cannot request more than "
                             f"{max_log_probs} logprobs. "
                             "Please increase max_log_probs.")
        if arrival_time is None:
            arrival_time = time.monotonic()
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request,
        )

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.tokenizer.get_lora_tokenizer(
            lora_request).eos_token_id
        seq = Sequence(
            seq_id,
            prompt,
            prompt_token_ids,
            block_size,
            eos_token_id,
            lora_request,
        )

        # Defensive copy of SamplingParams, which are used by the sampler.
        # Note that this does not deep-copy LogitsProcessor objects.
        sampling_params = sampling_params.clone()
        # Inject the eos token id into the sampling_params to support
        # min_tokens processing.
        sampling_params.eos_token_id = seq.eos_token_id

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time, lora_request, multi_modal_data)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a request(s) with the given ID.

        Args:
            request_id: The ID(s) of the request to abort.

        Details:
            - Refer to the
              :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
              from class :class:`~aphrodite.processing.scheduler.Scheduler`.

        Example:
            >>> # initialize engine and add a request with request_id
            >>> request_id = str(0)
            >>> # abort the request
            >>> engine.abort_request(request_id)
        """
        self.scheduler.abort_seq_group(request_id)

    def get_model_config(self) -> ModelConfig:
        """Gets the model configuration."""
        return self.model_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id,
        )
        if early_stopping is False:
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id,
            )
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len,
                )
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length,
                    ))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the
                # current sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                    ))
        return current_worst_score >= highest_attainable_score

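    # Note (added for clarity; this describes `Sequence.get_beam_search_score`,
    # which lives outside this file, so treat it as an assumption): the score
    # is expected to be the usual length-penalized quantity, roughly
    #     cumulative_logprob / (seq_len ** length_penalty)
    # With length_penalty > 0 a longer sequence can still improve its score,
    # which is why the "never" branch above evaluates the best running
    # sequence at the longest length it could possibly reach.
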
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        # Process prompt logprobs.
        prompt_logprobs = outputs.prompt_logprobs
        if (prompt_logprobs is not None
                and seq_group.sampling_params.detokenize):
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples.
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent).
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence.
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child.persistent_data = child_sample.persistent_data
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            parent.persistent_data = last_child_sample.persistent_data
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
                new_char_count = 0
            self._check_stop(seq, new_char_count, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence
            # group and fork them in the block manager if they are not
            # finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in the
            # block manager. Keep them in the sequence group as candidate
            # output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return
        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new).
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(
            key=lambda x: x[0].get_beam_search_score(
                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
            reverse=True,
        )
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need to remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # Select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(
            key=lambda x: x[0].get_beam_search_score(
                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
            reverse=True,
        )

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria.
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params,
                best_running_seq,
                current_worst_seq,
            )

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width
            # sequences to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in the block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in the
        # block manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in the block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for the
                # next iteration.
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

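    # Worked example of the fork logic above (illustrative): with best_of=2
    # and beam search disabled, the first decoding step after the prompt can
    # yield two SequenceOutputs whose parent_seq_id point at the same running
    # sequence. The first sample is materialized via parent.fork() as a new
    # child (and forked in the block manager), while the last sample is
    # appended to the parent in place to avoid a redundant copy.
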
    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
        now = time.time()
        # Update the scheduled sequence groups with the model outputs.
        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
        for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)
            # If uncomputed tokens > 0, it means prefill is chunked.
            # We don't need to process outputs in that case.
            if seq_group.get_num_uncomputed_tokens() == 0:
                self._process_sequence_group_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for scheduled_seq_group in scheduled_seq_groups:
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        for seq_group in scheduler_outputs.ignored_seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)

        # Log stats.
        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs))

        return request_outputs

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
            :alt: Overview of the step function
            :align: center

            Overview of the step function.

        Details:
            - Step 1: Schedules the sequences to be executed in the next
              iteration and the token blocks to be swapped in/out/copied.

                - Depending on the scheduling policy,
                  sequences may be `preempted/reordered`.
                - A Sequence Group (SG) refers to a group of sequences
                  that are generated from the same prompt.

            - Step 2: Calls the distributed executor to execute the model.
            - Step 3: Processes the model output. This mainly includes:

                - Decodes the relevant outputs.
                - Updates the scheduled sequence groups with model outputs
                  based on their `sampling parameters` (`use_beam_search`
                  or not).
                - Frees the finished sequence groups.

            - Finally, it creates and returns the newly generated results.

        Example:
            >>> # Please see the example/ folder for more detailed examples.
            >>>
            >>> # initialize engine and request arguments
            >>> engine = AphroditeEngine.from_engine_args(engine_args)
            >>> example_inputs = [(0, "What is LLM?",
            >>>    SamplingParams(temperature=0.0))]
            >>>
            >>> # Start the engine with an event loop
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
            >>>         engine.add_request(str(req_id), prompt, sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
            >>>     for request_output in request_outputs:
            >>>         if request_output.finished:
            >>>             # return or show the request output
            >>>
            >>>     if not (engine.has_unfinished_requests() or example_inputs):
            >>>         break
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if not scheduler_outputs.is_empty():
            output = self.model_executor.execute_model(
                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
                scheduler_outputs.blocks_to_swap_out,
                scheduler_outputs.blocks_to_copy)
        else:
            output = []

        return self._process_model_outputs(output, scheduler_outputs)

    def do_log_stats(self) -> None:
        """Forced log when no requests are active."""
        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs=None))

    def _get_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
        """Get Stats to be logged to Prometheus."""
        now = time.monotonic()

        # KV Cache Usage in %.
        num_total_gpu = self.cache_config.num_gpu_blocks
        num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
        gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage = 0.0
        if num_total_cpu > 0:
            num_free_cpu = (
                self.scheduler.block_manager.get_num_free_cpu_blocks())
            cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)

        # Scheduler State
        num_running = len(self.scheduler.running)
        num_swapped = len(self.scheduler.swapped)
        num_waiting = len(self.scheduler.waiting)

        # Iteration stats if we have scheduler output.
        num_prompt_tokens = 0
        num_generation_tokens = 0
        time_to_first_tokens = []
        time_per_output_tokens = []
        time_e2e_requests = []
        if scheduler_outputs is not None:
            prompt_run = scheduler_outputs.num_prefill_groups > 0

            # Number of Tokens.
            if prompt_run:
                num_prompt_tokens = sum(
                    len(scheduled_seq_group.seq_group.prompt_token_ids)
                    for scheduled_seq_group in
                    scheduler_outputs.scheduled_seq_groups)
                num_generation_tokens = sum(
                    scheduled_seq_group.seq_group.num_seqs()
                    for scheduled_seq_group in
                    scheduler_outputs.scheduled_seq_groups)
            else:
                num_generation_tokens = scheduler_outputs.num_batched_tokens

            # Latency Timings.
            time_last_iters = []
            for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
                seq_group = scheduled_seq_group.seq_group
                # Time since last token.
                # (n.b. updates seq_group.metrics.last_token_time)
                time_last_iters.append(seq_group.get_last_latency(now))
                # Time since arrival for all finished requests.
                if seq_group.is_finished():
                    time_e2e_requests.append(now -
                                             seq_group.metrics.arrival_time)

            time_to_first_tokens = time_last_iters if prompt_run else []
            time_per_output_tokens = [] if prompt_run else time_last_iters

        return Stats(
            now=now,
            num_running=num_running,
            num_swapped=num_swapped,
            num_waiting=num_waiting,
            gpu_cache_usage=gpu_cache_usage,
            cpu_cache_usage=cpu_cache_usage,
            num_prompt_tokens=num_prompt_tokens,
            num_generation_tokens=num_generation_tokens,
            time_to_first_tokens=time_to_first_tokens,
            time_per_output_tokens=time_per_output_tokens,
            time_e2e_requests=time_e2e_requests,
        )

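    # Example of the cache-usage arithmetic above (numbers are made up):
    # with num_gpu_blocks=1000 and 250 free blocks,
    # gpu_cache_usage = 1.0 - 250 / 1000 = 0.75, i.e. 75% of the GPU KV
    # cache is occupied.
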
    def _check_stop(self, seq: Sequence, new_char_count: int,
                    sampling_params: SamplingParams) -> None:
        """Stop the finished sequences.

        new_char_count is the number of chars added to the
        sequence's output text for the newly generated token.
        """

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not.
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id()
        if last_token_id in sampling_params.stop_token_ids:
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove the last token.
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop_str = self._check_stop_strings(seq, new_char_count,
                                            sampling_params)
        if stop_str is not None:
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Check if any stop strings are matched and truncate the sequence's
        output text accordingly.

        Returns the stop string if matched or else None.
        """
        if not new_char_count:
            return None

        for stop_str in sampling_params.stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text.
            stop_index = seq.output_text.find(
                stop_str, -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Truncate to the end of the stop string.
                stop_index += stop_string_len
                if stop_index >= len(seq.output_text):
                    # No truncation required.
                    return stop_str

            # Truncate the output text to either the beginning
            # or the end of the stop string.
            seq.output_text = seq.output_text[:stop_index]
            return stop_str
        return None

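    # Worked example of the search window in _check_stop_strings
    # (illustrative values): if seq.output_text ends in "END", was just
    # extended by one character (new_char_count=1), and stop_str == "END"
    # (stop_string_len=3), the search starts at index -1 - 3 = -4, so only
    # the last four characters are scanned. That window is enough to catch a
    # stop string completed by the newly added character while avoiding a
    # rescan of the whole output text on every step.
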
    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> List[int]:
        return self.model_executor.list_loras()

    def check_health(self) -> None:
        self.model_executor.check_health()


setup_logger()