single_step.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. from typing import Dict, List, Tuple, Union
  2. from aphrodite.common.config import SchedulerConfig
  3. from aphrodite.common.sampling_params import SamplingParams
  4. from aphrodite.common.sequence import (Sequence, SequenceGroup,
  5. SequenceGroupOutput, SequenceOutput,
  6. SequenceStatus)
  7. from aphrodite.common.utils import Counter
  8. from aphrodite.engine.output_processor.interfaces import \
  9. SequenceGroupOutputProcessor
  10. from aphrodite.engine.output_processor.stop_checker import StopChecker
  11. from aphrodite.processing.scheduler import Scheduler
  12. from aphrodite.transformers_utils.detokenizer import Detokenizer
  13. class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
  14. """SequenceGroupOutputProcessor which handles "output processing" logic,
  15. which happens after the model returns generated token ids and before
  16. scheduling of the next batch. Output processing logic includes
  17. detokenization, and determining if a sequence is finished (e.g. via max len
  18. or eos token).
  19. The SingleStepOutputProcessor is specialized to the case where the model
  20. emits at most a single token per invocation, which precludes configurations
  21. such as speculative decoding or multi-step decoding. This enables beam
  22. search sampling, which requires forking/finishing/freeing sequences in a way
  23. that is currently difficult to schedule multiple steps ahead of time.
  24. """
  25. def __init__(
  26. self,
  27. scheduler_config: SchedulerConfig,
  28. detokenizer: Detokenizer,
  29. scheduler: Scheduler,
  30. seq_counter: Counter,
  31. stop_checker: StopChecker,
  32. ):
  33. self.scheduler_config = scheduler_config
  34. self.detokenizer = detokenizer
  35. self.scheduler = scheduler
  36. self.seq_counter = seq_counter
  37. self.stop_checker = stop_checker
  38. def process_outputs(self, sequence_group: SequenceGroup,
  39. outputs: List[SequenceGroupOutput]) -> None:
  40. """Append all new tokens to sequences in the sequence group. Fork any
  41. surviving beam candidates; free any unsurviving ones.
  42. Invokes detokenizer to detokenize new tokens, and also marks sequences
  43. as finished if they meet stop conditions.
  44. """
  45. assert (len(outputs) == 1
  46. ), f"{type(self)} does not support multiple outputs per step"
  47. return self._process_sequence_group_outputs(sequence_group, outputs[0])
  48. def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
  49. outputs: SequenceGroupOutput) -> None:
  50. # Process prompt logprobs
  51. prompt_logprobs = outputs.prompt_logprobs
  52. if prompt_logprobs is not None and \
  53. seq_group.sampling_params.detokenize and self.detokenizer:
  54. self.detokenizer.decode_prompt_logprobs_inplace(
  55. seq_group, prompt_logprobs)
  56. seq_group.prompt_logprobs = prompt_logprobs
  57. # Process samples
  58. samples = outputs.samples
  59. parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  60. existing_finished_seqs = seq_group.get_finished_seqs()
  61. parent_child_dict: Dict[int, List[SequenceOutput]] = {
  62. parent_seq.seq_id: []
  63. for parent_seq in parent_seqs
  64. }
  65. for sample in samples:
  66. parent_child_dict[sample.parent_seq_id].append(sample)
  67. # List of (child, parent)
  68. child_seqs: List[Tuple[Sequence, Sequence]] = []
  69. # Process the child samples for each parent sequence
  70. for parent in parent_seqs:
  71. child_samples: List[SequenceOutput] = parent_child_dict[
  72. parent.seq_id]
  73. if len(child_samples) == 0:
  74. # This parent sequence has no children samples. Remove
  75. # the parent sequence from the sequence group since it will
  76. # not be used in the future iterations.
  77. parent.status = SequenceStatus.FINISHED_ABORTED
  78. seq_group.remove(parent.seq_id)
  79. self.scheduler.free_seq(parent)
  80. continue
  81. # Fork the parent sequence if there are multiple child samples.
  82. for child_sample in child_samples[:-1]:
  83. new_child_seq_id: int = next(self.seq_counter)
  84. child = parent.fork(new_child_seq_id)
  85. child.append_token_id(child_sample.output_token,
  86. child_sample.logprobs)
  87. child_seqs.append((child, parent))
  88. # Continue the parent sequence for the last child sample.
  89. # We reuse the parent sequence here to reduce redundant memory
  90. # copies, especially when using non-beam search sampling methods.
  91. last_child_sample = child_samples[-1]
  92. parent.append_token_id(last_child_sample.output_token,
  93. last_child_sample.logprobs)
  94. child_seqs.append((parent, parent))
  95. for seq, _ in child_seqs:
  96. if seq_group.sampling_params.detokenize and self.detokenizer:
  97. new_char_count = self.detokenizer.decode_sequence_inplace(
  98. seq, seq_group.sampling_params)
  99. else:
  100. new_char_count = 0
  101. self.stop_checker.maybe_stop_sequence(seq, new_char_count,
  102. seq_group.sampling_params)
  103. # Non-beam search case
  104. if not seq_group.sampling_params.use_beam_search:
  105. # For newly created child sequences, add them to the sequence group
  106. # and fork them in block manager if they are not finished.
  107. for seq, parent in child_seqs:
  108. if seq is not parent:
  109. seq_group.add(seq)
  110. if not seq.is_finished():
  111. self.scheduler.fork_seq(parent, seq)
  112. # Free the finished and selected parent sequences' memory in block
  113. # manager. Keep them in the sequence group as candidate output.
  114. # NOTE: we need to fork the new sequences before freeing the
  115. # old sequences.
  116. for seq, parent in child_seqs:
  117. if seq is parent and seq.is_finished():
  118. self.scheduler.free_seq(seq)
  119. return
  120. # Beam search case
  121. # Select the child sequences to keep in the sequence group.
  122. selected_child_seqs = []
  123. unselected_child_seqs = []
  124. beam_width = seq_group.sampling_params.best_of
  125. length_penalty = seq_group.sampling_params.length_penalty
  126. # Select the newly finished sequences with the highest scores
  127. # to replace existing finished sequences.
  128. # Tuple of (seq, parent, is_new)
  129. existing_finished_seqs = [(seq, None, False)
  130. for seq in existing_finished_seqs]
  131. new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
  132. if seq.is_finished()]
  133. all_finished_seqs = existing_finished_seqs + new_finished_seqs
  134. # Sort the finished sequences by their scores.
  135. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  136. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  137. reverse=True)
  138. for seq, parent, is_new in all_finished_seqs[:beam_width]:
  139. if is_new:
  140. # A newly generated child sequence finishes and has a high
  141. # score, so we will add it into the sequence group.
  142. selected_child_seqs.append((seq, parent))
  143. for seq, parent, is_new in all_finished_seqs[beam_width:]:
  144. if is_new:
  145. # A newly generated child sequence finishes but has a low
  146. # score, so we will not add it into the sequence group.
  147. # Additionally, if this sequence is a continuation of a
  148. # parent sequence, we will need remove the parent sequence
  149. # from the sequence group.
  150. unselected_child_seqs.append((seq, parent))
  151. else:
  152. # An existing finished sequence has a low score, so we will
  153. # remove it from the sequence group.
  154. seq_group.remove(seq.seq_id)
  155. # select the top beam_width sequences from the running
  156. # sequences for the next iteration to continue the beam
  157. # search.
  158. running_child_seqs = [(seq, parent) for seq, parent in child_seqs
  159. if not seq.is_finished()]
  160. # Sort the running sequences by their scores.
  161. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  162. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  163. reverse=True)
  164. # Check if we can stop the beam search.
  165. if len(running_child_seqs) == 0:
  166. # No running sequences, stop the beam search.
  167. stop_beam_search = True
  168. elif len(all_finished_seqs) < beam_width:
  169. # Not enough finished sequences, continue the beam search.
  170. stop_beam_search = False
  171. else:
  172. # Check the early stopping criteria
  173. best_running_seq = running_child_seqs[0][0]
  174. current_worst_seq = all_finished_seqs[beam_width - 1][0]
  175. stop_beam_search = self._check_beam_search_early_stopping(
  176. seq_group.sampling_params.early_stopping,
  177. seq_group.sampling_params, best_running_seq, current_worst_seq)
  178. if stop_beam_search:
  179. # Stop the beam search and remove all the running sequences from
  180. # the sequence group.
  181. unselected_child_seqs.extend(running_child_seqs)
  182. else:
  183. # Continue the beam search and select the top beam_width sequences
  184. # to continue the beam search.
  185. selected_child_seqs.extend(running_child_seqs[:beam_width])
  186. # The remaining running sequences will not be used in the next
  187. # iteration. Again, if these sequences are continuations of
  188. # parent sequences, we will need to remove the parent sequences
  189. # from the sequence group.
  190. unselected_child_seqs.extend(running_child_seqs[beam_width:])
  191. # For newly created child sequences, add them to the sequence group
  192. # and fork them in block manager if they are not finished.
  193. for seq, parent in selected_child_seqs:
  194. if seq is not parent:
  195. seq_group.add(seq)
  196. if not seq.is_finished():
  197. self.scheduler.fork_seq(parent, seq)
  198. # Free the finished and selected parent sequences' memory in block
  199. # manager. Keep them in the sequence group as candidate output.
  200. for seq, parent in selected_child_seqs:
  201. if seq is parent and seq.is_finished():
  202. self.scheduler.free_seq(seq)
  203. # Remove the unselected parent sequences from the sequence group and
  204. # free their memory in block manager.
  205. for seq, parent in unselected_child_seqs:
  206. if seq is parent:
  207. # Remove the parent sequence if it is not selected for next
  208. # iteration
  209. seq_group.remove(seq.seq_id)
  210. self.scheduler.free_seq(seq)
  211. def _check_beam_search_early_stopping(
  212. self,
  213. early_stopping: Union[bool, str],
  214. sampling_params: SamplingParams,
  215. best_running_seq: Sequence,
  216. current_worst_seq: Sequence,
  217. ) -> bool:
  218. assert sampling_params.use_beam_search
  219. length_penalty = sampling_params.length_penalty
  220. if early_stopping is True:
  221. return True
  222. current_worst_score = current_worst_seq.get_beam_search_score(
  223. length_penalty=length_penalty,
  224. eos_token_id=current_worst_seq.eos_token_id)
  225. if early_stopping is False:
  226. highest_attainable_score = best_running_seq.get_beam_search_score(
  227. length_penalty=length_penalty,
  228. eos_token_id=best_running_seq.eos_token_id)
  229. else:
  230. assert early_stopping == "never"
  231. if length_penalty > 0.0:
  232. # If length_penalty > 0.0, beam search will prefer longer
  233. # sequences. The highest attainable score calculation is
  234. # based on the longest possible sequence length in this case.
  235. max_possible_length = max(
  236. best_running_seq.get_prompt_len() +
  237. sampling_params.max_tokens,
  238. self.scheduler_config.max_model_len)
  239. highest_attainable_score = (
  240. best_running_seq.get_beam_search_score(
  241. length_penalty=length_penalty,
  242. eos_token_id=best_running_seq.eos_token_id,
  243. seq_len=max_possible_length))
  244. else:
  245. # Otherwise, beam search will prefer shorter sequences. The
  246. # highest attainable score calculation is based on the current
  247. # sequence length.
  248. highest_attainable_score = (
  249. best_running_seq.get_beam_search_score(
  250. length_penalty=length_penalty,
  251. eos_token_id=best_running_seq.eos_token_id))
  252. return current_worst_score >= highest_attainable_score