multi_step.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import functools
  2. from typing import Callable, List
  3. from transformers import PreTrainedTokenizer
  4. from aphrodite.common.logger import log_once
  5. from aphrodite.common.sampling_params import SamplingParams
  6. from aphrodite.common.sequence import (Sequence, SequenceGroup,
  7. SequenceGroupOutput, SequenceOutput,
  8. SequenceStatus)
  9. from aphrodite.common.utils import Counter
  10. from aphrodite.engine.output_processor.interfaces import (
  11. SequenceGroupOutputProcessor)
  12. from aphrodite.engine.output_processor.single_step import (
  13. single_step_process_prompt_logprob)
  14. from aphrodite.engine.output_processor.stop_checker import StopChecker
  15. from aphrodite.processing.scheduler import Scheduler
  16. from aphrodite.transformers_utils.detokenizer import Detokenizer
  17. class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
  18. """SequenceGroupOutputProcessor which handles logic related to
  19. detokenization and stopping conditions. It specializes to "multi-step
  20. decoding", where Aphrodite's worker may generate multiple tokens per
  21. invocation. This is currently mutually exclusive with advanced sampling
  22. techniques like beam search, which motivates the separation of this logic
  23. from the single step output processor.
  24. This class is responsible for things such as correctly appending all new
  25. token ids to their sequence, detokenizing new token ids, truncating new
  26. output tokens after an eos token, and correctly handling the case where the
  27. number of new output tokens per sequence differs in a single batch.
  28. """
  29. def __init__(
  30. self,
  31. detokenizer: Detokenizer,
  32. scheduler: List[Scheduler],
  33. seq_counter: Counter,
  34. get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
  35. stop_checker: StopChecker,
  36. ):
  37. self.detokenizer = detokenizer
  38. self.scheduler = scheduler
  39. self.seq_counter = seq_counter
  40. self.get_tokenizer_for_seq = get_tokenizer_for_seq
  41. self.stop_checker = stop_checker
  42. def process_prompt_logprob(self, seq_group: SequenceGroup,
  43. outputs: List[SequenceGroupOutput]) -> None:
  44. """Process prompt logprobs associated with each step of a multi-step-
  45. scheduled computation.
  46. Args:
  47. seq_group: the outputs are associated with this :class:`SequenceGroup`
  48. outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
  49. """
  50. for output in outputs:
  51. # Concatenate single-step prompt logprob processing results.
  52. single_step_process_prompt_logprob(self, seq_group, output)
  53. @staticmethod
  54. @functools.lru_cache()
  55. def _log_prompt_logprob_unsupported_warning_once():
  56. log_once(
  57. level="WARNING",
  58. message="Prompt logprob is not supported by multi step workers. "
  59. "(e.g., speculative decode uses multi step workers).")
  60. def process_outputs(self,
  61. sequence_group: SequenceGroup,
  62. outputs: List[SequenceGroupOutput],
  63. is_async: bool = False) -> None:
  64. """Append new tokens in the outputs to sequences in the sequence group.
  65. This only supports sequence groups of size 1. It supports greater than
  66. one new token per sequence.
  67. This applies logic like stop condition checking and detokenization.
  68. It also handles cases where there are tokens emitted after
  69. the EOS token.
  70. is_async - Indicates whether this postprocessor runs in
  71. parallel with the GPU forward pass and is processing
  72. tokens from the previous step. If this is true, then
  73. no tokens need to be appended since it is already done
  74. externally (before the next schedule() call)
  75. """
  76. # TODO: Add support for async if necessary
  77. assert not is_async
  78. # Sequences can be in RUNNING or FINISHED_ABORTED state
  79. # once scheduled, as a sequence is moved to FINSIHED_ABORTED
  80. # if a client disconnects from the api server.
  81. seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
  82. if seqs is None:
  83. seqs = sequence_group.get_seqs(
  84. status=SequenceStatus.FINISHED_ABORTED)
  85. assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
  86. assert len(seqs) == 1, (
  87. "Beam search not supported in multi-step decoding.")
  88. seq = seqs[0]
  89. # Since there's only one sequence per sequence group, we can take the
  90. # first sample.
  91. samples = [output.samples[0] for output in outputs]
  92. # -1 means the output token is not valid (eg. due to spec decode
  93. # rejecting tokens).
  94. valid_samples = [
  95. sample for sample in samples if sample.output_token != -1
  96. ]
  97. assert valid_samples
  98. self._process_seq_outputs(seq, valid_samples,
  99. sequence_group.sampling_params)
  100. def _process_seq_outputs(self, seq: Sequence,
  101. valid_samples: List[SequenceOutput],
  102. sampling_params: SamplingParams) -> None:
  103. output_token_ids = [sample.output_token for sample in valid_samples]
  104. output_logprobs = [sample.logprobs for sample in valid_samples]
  105. # Truncate to max_tokens if necessary.
  106. remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
  107. len(output_token_ids))
  108. if remaining_tokens < 0:
  109. valid_samples = valid_samples[:remaining_tokens]
  110. output_token_ids = output_token_ids[:remaining_tokens]
  111. # Truncate any tokens after EOS. This is required as spec decode
  112. # generates a fixed number of tokens without evaluating stopping
  113. # conditions within the block. This can cause an eos token to be
  114. # unintentionally ignored.
  115. if not sampling_params.ignore_eos:
  116. eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
  117. # Avoiding .index calls as exception throwing in the happy path
  118. # is expensive.
  119. for i in range(len(output_token_ids)):
  120. if output_token_ids[i] == eos_token_id:
  121. output_token_ids = output_token_ids[:i + 1]
  122. valid_samples = valid_samples[:i + 1]
  123. break
  124. # Incrementally append tokens to the sequence, as if we had only one new
  125. # token.
  126. for output_token_id, output_logprob in zip(output_token_ids,
  127. output_logprobs):
  128. seq.append_token_id(
  129. token_id=output_token_id,
  130. logprobs=output_logprob,
  131. )
  132. new_char_count = 0
  133. if sampling_params.detokenize:
  134. new_char_count = self.detokenizer.decode_sequence_inplace(
  135. seq, sampling_params)
  136. self.stop_checker.maybe_stop_sequence(
  137. seq,
  138. new_char_count=new_char_count,
  139. sampling_params=sampling_params)
  140. if seq.is_finished():
  141. break