multi_step.py

from typing import Callable, List

from transformers import PreTrainedTokenizer

from aphrodite.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.processing.scheduler import Scheduler
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (Logprob, Sequence, SequenceGroup,
                                       SequenceGroupOutput, SequenceOutput,
                                       SequenceStatus)
from aphrodite.transformers_utils.detokenizer import Detokenizer
from aphrodite.common.utils import Counter


class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where Aphrodite's worker may generate multiple tokens per
    invocation.

    This is currently mutually exclusive with advanced sampling techniques
    like beam search, which motivates the separation of this logic from the
    single-step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where
    the number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: StopChecker,
    ):
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append new tokens in the outputs to sequences in the sequence
        group.

        This only supports sequence groups of size 1. It supports greater
        than one new token per sequence.

        This applies logic like stop condition checking and detokenization,
        including freeing finished sequences. It also handles cases where
        there are tokens emitted after the EOS token.
        """
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)

        assert seqs, "expected running sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]

        # Since there's only one sequence per sequence group, we can take
        # the first sample.
        samples = [outputs[step].samples[0] for step in range(len(outputs))]

        # -1 means the output token is not valid (e.g. due to spec decode
        # rejecting tokens).
        valid_samples = [
            sample for sample in samples if sample.output_token != -1
        ]
        assert valid_samples

        self._process_seq_outputs(seq, valid_samples,
                                  sequence_group.sampling_params)

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        output_token_ids = [sample.output_token for sample in valid_samples]

        # Truncate to max_tokens if necessary. When the budget is exceeded,
        # `remaining_tokens` is negative, so the slices below drop exactly
        # that many tokens from the end.
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
                                                         len(output_token_ids))
        if remaining_tokens < 0:
            valid_samples = valid_samples[:remaining_tokens]
            output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i in range(len(output_token_ids)):
                if output_token_ids[i] == eos_token_id:
                    output_token_ids = output_token_ids[:i + 1]
                    valid_samples = valid_samples[:i + 1]
                    break
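
        # Note: the slices above end at ``i + 1`` so the EOS token itself is
        # kept; the stop checker invoked below can then observe it and mark
        # the sequence as finished.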

        # Incrementally append tokens to the sequence, as if we had only one
        # new token.
        for output_token_id in output_token_ids:
            seq.append_token_id(
                token_id=output_token_id,
                # TODO emit logprobs in multi-step decoding.
                logprobs={output_token_id: Logprob(0.0)},
            )

            new_char_count = 0
            if sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)

            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count=new_char_count,
                sampling_params=sampling_params)
            if seq.is_finished():
                break

        if seq.is_finished():
            self.scheduler.free_seq(seq)
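

# Illustrative sketch (not part of the processor above): a standalone rerun
# of the truncation rules used in `_process_seq_outputs`, applied to plain
# token-id lists. The helper name and the sample token ids below are
# hypothetical and exist only to show the order of operations: the
# max_tokens budget is applied first, then EOS truncation.
if __name__ == "__main__":

    def truncate_multi_step_tokens(token_ids, prior_output_len, max_tokens,
                                   eos_token_id, ignore_eos=False):
        """Mirror the truncation order above: max_tokens first, then EOS."""
        remaining = max_tokens - (prior_output_len + len(token_ids))
        if remaining < 0:
            # Negative slice drops the overflow tokens from the end.
            token_ids = token_ids[:remaining]
        if not ignore_eos:
            for i, token_id in enumerate(token_ids):
                if token_id == eos_token_id:
                    # Keep the EOS token itself; drop everything after it.
                    token_ids = token_ids[:i + 1]
                    break
        return token_ids

    # 5 proposed tokens, 6 already generated, budget of 8 -> only 2 fit, and
    # the EOS (id 2) falls outside the kept prefix: [11, 12]
    print(truncate_multi_step_tokens([11, 12, 2, 13, 14], 6, 8, eos_token_id=2))
    # Budget is not exceeded, but EOS appears mid-block -> keep up to and
    # including the EOS: [11, 2]
    print(truncate_multi_step_tokens([11, 2, 13], 0, 16, eos_token_id=2))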