import functools
from typing import Callable, List

from transformers import PreTrainedTokenizer

from aphrodite.common.logger import log_once
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                       SequenceGroupOutput, SequenceOutput,
                                       SequenceStatus)
from aphrodite.common.utils import Counter
from aphrodite.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.processing.scheduler import Scheduler
from aphrodite.transformers_utils.detokenizer import Detokenizer


class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where Aphrodite's worker may generate multiple tokens per
    invocation.

    This is currently mutually exclusive with advanced sampling techniques
    like beam search, which motivates the separation of this logic from the
    single step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where
    the number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while processing outputs.

        Args:
            detokenizer: Used to incrementally decode appended tokens.
            scheduler: Every scheduler in this list gets ``free_seq`` called
                on a sequence once it finishes.
            seq_counter: Stored on the instance; not used by the methods of
                this class.
            get_tokenizer_for_seq: Returns the tokenizer for a sequence. Only
                its ``eos_token_id`` is read here.
            stop_checker: Decides after each appended token whether the
                sequence should stop.
        """
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Ignore prompt logprobs and emit a one-time warning instead."""
        # TODO: Prompt logprob currently not implemented in multi step
        # workers.
        self._log_prompt_logprob_unsupported_warning_once()

    @staticmethod
    @functools.lru_cache()
    def _log_prompt_logprob_unsupported_warning_once():
        # lru_cache on an argument-less staticmethod memoizes the single
        # call, so the body runs at most once per process.
        log_once(
            level="WARNING",
            message="Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization,
        including freeing finished sequences. It also handles cases where
        there are tokens emitted after the EOS token.

        Args:
            sequence_group: Group holding exactly one RUNNING sequence.
            outputs: One output per generated step. Only the first sample of
                each output is read.
        """
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)

        assert seqs, "expected running sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]

        # Since there's only one sequence per sequence group, we can take the
        # first sample.
        samples = [output.samples[0] for output in outputs]

        # -1 means the output token is not valid (eg. due to spec decode
        # rejecting tokens).
        valid_samples = [
            sample for sample in samples if sample.output_token != -1
        ]
        assert valid_samples

        self._process_seq_outputs(seq, valid_samples,
                                  sequence_group.sampling_params)

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append accepted tokens to ``seq`` one at a time.

        The token list is truncated to respect ``max_tokens`` and to drop
        anything after an EOS token. Each remaining token is then appended
        and optionally detokenized, and the stop conditions are checked
        after every token. If the sequence finishes, it is freed in every
        scheduler.
        """
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]

        # Truncate to max_tokens if necessary. If no limit is set (None),
        # skip this step; previously the subtraction raised TypeError.
        # Logprobs are truncated alongside the ids so the two lists stay
        # paired.
        if sampling_params.max_tokens is not None:
            remaining_tokens = sampling_params.max_tokens - (
                seq.get_output_len() + len(output_token_ids))
            if remaining_tokens < 0:
                output_token_ids = output_token_ids[:remaining_tokens]
                output_logprobs = output_logprobs[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i, token_id in enumerate(output_token_ids):
                if token_id == eos_token_id:
                    # Keep the EOS token itself, drop everything after it.
                    output_token_ids = output_token_ids[:i + 1]
                    output_logprobs = output_logprobs[:i + 1]
                    break

        # Incrementally append tokens to the sequence, as if we had only one new
        # token.
        for output_token_id, output_logprob in zip(output_token_ids,
                                                   output_logprobs):
            seq.append_token_id(
                token_id=output_token_id,
                logprobs=output_logprob,
            )

            new_char_count = 0
            if sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)

            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count=new_char_count,
                sampling_params=sampling_params)
            # Stop appending as soon as a stop condition fires; the remaining
            # speculative tokens are discarded.
            if seq.is_finished():
                break

        if seq.is_finished():
            for scheduler in self.scheduler:
                scheduler.free_seq(seq)