multi_step.py

from typing import Callable, Iterable, List

from transformers import PreTrainedTokenizer

from aphrodite.processing.scheduler import Scheduler
from aphrodite.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (Logprob, Sequence, SequenceGroup,
                                       SequenceGroupOutput, SequenceOutput,
                                       SequenceStatus)
from aphrodite.transformers_utils.detokenizer import Detokenizer


class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where Aphrodite's worker may generate multiple tokens per
    invocation.

    This is currently mutually exclusive with advanced sampling techniques
    like beam search, which motivates the separation of this logic from the
    single-step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where
    the number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: StopChecker,
    ):
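        # Collaborators:
        #  - detokenizer: incrementally turns new token ids into text.
        #  - scheduler: owns the sequence lifecycle; used here to free
        #    finished sequences.
        #  - seq_counter: source of new sequence ids (not used by this
        #    processor).
        #  - get_tokenizer_for_seq: resolves the tokenizer (and its EOS token
        #    id) for a given sequence.
        #  - stop_checker: applies the stopping conditions after each
        #    appended token.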
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization,
        including freeing finished sequences. It also handles cases where
        there are tokens emitted after the EOS token.
        """
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)

        assert seqs, "expected running sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]

        # Since there's only one sequence per sequence group, we can take the
        # first sample.
        samples = [outputs[step].samples[0] for step in range(len(outputs))]
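        # `outputs` has one entry per decode step, so `samples` lists the
        # proposed tokens for this sequence in generation order.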

        # -1 means the output token is not valid (e.g. due to spec decode
        # rejecting tokens).
        valid_samples = [
            sample for sample in samples if sample.output_token != -1
        ]
        assert valid_samples

        self._process_seq_outputs(seq, valid_samples,
                                  sequence_group.sampling_params)

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        output_token_ids = [sample.output_token for sample in valid_samples]

        # Truncate to max_tokens if necessary.
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
                                                         len(output_token_ids))
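        # remaining_tokens is negative when the new tokens would overshoot
        # max_tokens; e.g. with budget for 2 more tokens and 5 proposed,
        # remaining_tokens == -3 and the [:-3] slices below keep the first 2.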
        if remaining_tokens < 0:
            valid_samples = valid_samples[:remaining_tokens]
            output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i in range(len(output_token_ids)):
                if output_token_ids[i] == eos_token_id:
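                    # Slice to [:i + 1] so the EOS token itself is kept and
                    # everything proposed after it is dropped.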
                    output_token_ids = output_token_ids[:i + 1]
                    valid_samples = valid_samples[:i + 1]
                    break

        # Incrementally append tokens to the sequence, as if we had only one
        # new token.
        for output_token_id in output_token_ids:
            seq.append_token_id(
                token_id=output_token_id,
                # TODO emit logprobs in multi-step decoding.
                logprobs={output_token_id: Logprob(0.0)},
            )
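            # Logprob(0.0) is a placeholder value until real logprobs are
            # emitted for multi-step decoding (see TODO above).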

            new_char_count = 0
            if sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)

            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count=new_char_count,
                sampling_params=sampling_params)
            if seq.is_finished():
                break

        if seq.is_finished():
            self.scheduler.free_seq(seq)
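

# A minimal sketch of the call pattern (hypothetical driver code; the real
# engine constructs and wires the collaborators itself):
#
#     processor = MultiStepOutputProcessor(detokenizer, scheduler, seq_counter,
#                                          get_tokenizer_for_seq, stop_checker)
#     # `outputs`: one SequenceGroupOutput per decode step for `seq_group`.
#     processor.process_outputs(seq_group, outputs)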