123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- from typing import Dict, List, Tuple, Union
- from aphrodite.common.config import SchedulerConfig
- from aphrodite.common.sampling_params import SamplingParams
- from aphrodite.common.sequence import (Sequence, SequenceGroup,
- SequenceGroupOutput, SequenceOutput,
- SequenceStatus)
- from aphrodite.common.utils import Counter
- from aphrodite.engine.output_processor.interfaces import (
- SequenceGroupOutputProcessor)
- from aphrodite.engine.output_processor.stop_checker import StopChecker
- from aphrodite.processing.scheduler import Scheduler
- from aphrodite.transformers_utils.detokenizer import Detokenizer
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max
    len or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes
    configurations such as speculative decoding or multi-step decoding. This
    enables beam search sampling, which requires forking/finishing/freeing
    sequences in a way that is currently difficult to schedule multiple steps
    ahead of time.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        stop_checker: StopChecker,
    ) -> None:
        """Store the collaborators used while processing model outputs.

        Args:
            scheduler_config: Read for ``max_model_len`` when evaluating the
                beam search early-stopping bound.
            detokenizer: Used to decode new tokens and prompt logprobs; a
                falsy value disables detokenization.
            scheduler: Schedulers that are all notified whenever a sequence
                is forked or freed.
            seq_counter: Iterator yielding ids for newly forked sequences.
            stop_checker: Marks sequences as finished when they meet a stop
                condition.
        """
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.

        Args:
            sequence_group: The group whose sequences receive the new tokens.
            outputs: Model outputs for this step; must contain exactly one
                element.
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0])

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Accumulate the prompt logprobs of one prefill step on the group.

        Args:
            seq_group: The group whose ``prompt_logprobs`` list is extended.
            outputs: Model outputs for this step; must contain exactly one
                element. Nothing happens when its ``prompt_logprobs`` is
                ``None``.
        """
        assert len(outputs) == 1, ("Single step should only has 1 output.")
        output = outputs[0]
        prompt_logprobs = output.prompt_logprobs

        # If this is the first (or only) "chunk" of the prefill, we need
        # to prepend None to the list of prompt logprobs. The reason for this
        # is that for N prompt tokens, the Sampler will generate N-1 total
        # prompt logprobs during prefill since the token at idx 0 will not
        # have a logprob associated with it.
        if prompt_logprobs is not None:
            if not seq_group.prompt_logprobs:
                prompt_logprobs = [None] + prompt_logprobs
                seq_group.prompt_logprobs = []
            if seq_group.sampling_params.detokenize and self.detokenizer:
                # The offset is the number of prompt logprobs accumulated by
                # earlier chunks, so decoding happens before extending.
                self.detokenizer.decode_prompt_logprobs_inplace(
                    seq_group,
                    prompt_logprobs,
                    position_offset=len(seq_group.prompt_logprobs))

            seq_group.prompt_logprobs.extend(prompt_logprobs)

    def _free_seq(self, seq: Sequence) -> None:
        """Ask every scheduler to free the resources held by ``seq``."""
        for scheduler in self.scheduler:
            scheduler.free_seq(seq)

    def _fork_seq(self, parent: Sequence, child: Sequence) -> None:
        """Ask every scheduler to fork ``parent`` into ``child``."""
        for scheduler in self.scheduler:
            scheduler.fork_seq(parent, child)

    def _decode_and_check_stop(self, seq: Sequence, seq_group: SequenceGroup,
                               sampling_params: SamplingParams) -> None:
        """Detokenize the newest token of ``seq`` and apply stop conditions."""
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)
        else:
            # Without detokenization no new characters were produced.
            new_char_count = 0
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count,
            sampling_params,
            lora_req=seq_group.lora_request,
        )

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to ``seq_group``.

        Each sample is appended to its parent sequence (forking the parent
        when it received several samples), the new tokens are detokenized and
        checked against the stop conditions, and then sequences are added to,
        kept in, or removed from the group. With beam search, only the best
        ``best_of`` finished and running candidates survive.

        Args:
            seq_group: The sequence group to update in place.
            outputs: The sampler output for this group at this step.
        """
        sampling_params = seq_group.sampling_params
        if sampling_params.n == 1 and not sampling_params.use_beam_search:
            # Fast path: a single sample for the group's single sequence, so
            # no forking or candidate bookkeeping is required.
            sample = outputs.samples[0]
            seq = seq_group.seqs[0]
            seq.append_token_id(sample.output_token, sample.logprobs)
            self._decode_and_check_stop(seq, seq_group, sampling_params)
            if seq.is_finished():
                self._free_seq(seq)
            return

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Taken before any stop check runs in this step, so it only holds
        # sequences that were finished in earlier steps.
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict: Dict[int, List[SequenceOutput]] = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            # Guard against a KeyError which can occur if the request was
            # aborted while the output was generated
            if (child_list :=
                    parent_child_dict.get(sample.parent_seq_id)) is not None:
                child_list.append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if not child_samples:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self._free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id: int = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            self._decode_and_check_stop(seq, seq_group, sampling_params)

        # Non-beam search case
        if not sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self._fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self._free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = sampling_params.best_of
        length_penalty = sampling_params.length_penalty

        def beam_score(entry) -> float:
            # Every candidate tuple stores its sequence in the first slot.
            candidate = entry[0]
            return candidate.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=candidate.eos_token_id)

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        old_finished_seqs = [(seq, None, False)
                             for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = old_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=beam_score, reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=beam_score, reverse=True)

        # Check if we can stop the beam search.
        if not running_child_seqs:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                sampling_params.early_stopping, sampling_params,
                best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self._fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self._free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self._free_seq(seq)

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        """Decide whether beam search can stop for this sequence group.

        Args:
            early_stopping: ``True`` stops immediately. ``False`` compares
                against the best running sequence's current score.
                ``"never"`` compares against the best score that sequence
                could still attain.
            sampling_params: Sampling parameters of the group; must have
                ``use_beam_search`` set.
            best_running_seq: The highest-scoring unfinished candidate.
            current_worst_seq: The lowest-scoring finished candidate that is
                still within the beam.

        Returns:
            ``True`` if the worst kept finished candidate scores at least as
            high as the attainable score of the best running candidate.
        """
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id)
        if early_stopping is False:
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                # A sequence is limited by both the per-request token budget
                # and the model length, so the longest reachable length is
                # the smaller of the two, not the larger.
                max_possible_length = self.scheduler_config.max_model_len
                max_tokens = sampling_params.max_tokens
                # NOTE(review): SamplingParams appears to allow an unset
                # (None) max_tokens; in that case only max_model_len bounds
                # the length. Confirm against SamplingParams.
                if max_tokens is not None:
                    max_possible_length = min(
                        best_running_seq.get_prompt_len() + max_tokens,
                        max_possible_length)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id))
        return current_worst_score >= highest_attainable_score
|