single_step.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from typing import Iterable, List, Tuple, Union
  2. from aphrodite.common.config import SchedulerConfig
  3. from aphrodite.common.sampling_params import SamplingParams
  4. from aphrodite.common.sequence import (Sequence, SequenceGroup,
  5. SequenceGroupOutput, SequenceOutput,
  6. SequenceStatus)
  7. from aphrodite.engine.output_processor.interfaces import \
  8. SequenceGroupOutputProcessor
  9. from aphrodite.engine.output_processor.stop_checker import StopChecker
  10. from aphrodite.processing.scheduler import Scheduler
  11. from aphrodite.transformers_utils.detokenizer import Detokenizer
  12. class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
  13. """SequenceGroupOutputProcessor which handles "output processing" logic,
  14. which happens after the model returns generated token ids and before
  15. scheduling of the next batch. Output processing logic includes
  16. detokenization, and determining if a sequence is finished (e.g. via max len
  17. or eos token).
  18. The SingleStepOutputProcessor is specialized to the case where the model
  19. emits at most a single token per invocation, which precludes configurations
  20. such as speculative decoding or multi-step decoding. This enables beam
  21. search sampling, which requires forking/finishing/freeing sequences in a way
  22. that is currently difficult to schedule multiple steps ahead of time.
  23. """
  24. def __init__(
  25. self,
  26. scheduler_config: SchedulerConfig,
  27. detokenizer: Detokenizer,
  28. scheduler: Scheduler,
  29. seq_counter: Iterable[int],
  30. stop_checker: StopChecker,
  31. ):
  32. self.scheduler_config = scheduler_config
  33. self.detokenizer = detokenizer
  34. self.scheduler = scheduler
  35. self.seq_counter = seq_counter
  36. self.stop_checker = stop_checker
  37. def process_outputs(self, sequence_group: SequenceGroup,
  38. outputs: List[SequenceGroupOutput]) -> None:
  39. """Append all new tokens to sequences in the sequence group. Fork any
  40. surviving beam candidates; free any unsurviving ones.
  41. Invokes detokenizer to detokenize new tokens, and also marks sequences
  42. as finished if they meet stop conditions.
  43. """
  44. assert (len(outputs) == 1
  45. ), f"{type(self)} does not support multiple outputs per step"
  46. return self._process_sequence_group_outputs(sequence_group, outputs[0])
  47. def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
  48. outputs: SequenceGroupOutput) -> None:
  49. # Process prompt logprobs
  50. prompt_logprobs = outputs.prompt_logprobs
  51. if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
  52. self.detokenizer.decode_prompt_logprobs_inplace(
  53. seq_group, prompt_logprobs)
  54. seq_group.prompt_logprobs = prompt_logprobs
  55. # Process samples
  56. samples = outputs.samples
  57. parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  58. existing_finished_seqs = seq_group.get_finished_seqs()
  59. parent_child_dict = {
  60. parent_seq.seq_id: []
  61. for parent_seq in parent_seqs
  62. }
  63. for sample in samples:
  64. parent_child_dict[sample.parent_seq_id].append(sample)
  65. # List of (child, parent)
  66. child_seqs: List[Tuple[Sequence, Sequence]] = []
  67. # Process the child samples for each parent sequence
  68. for parent in parent_seqs:
  69. child_samples: List[SequenceOutput] = parent_child_dict[
  70. parent.seq_id]
  71. if len(child_samples) == 0:
  72. # This parent sequence has no children samples. Remove
  73. # the parent sequence from the sequence group since it will
  74. # not be used in the future iterations.
  75. parent.status = SequenceStatus.FINISHED_ABORTED
  76. seq_group.remove(parent.seq_id)
  77. self.scheduler.free_seq(parent)
  78. continue
  79. # Fork the parent sequence if there are multiple child samples.
  80. for child_sample in child_samples[:-1]:
  81. new_child_seq_id = next(self.seq_counter)
  82. child = parent.fork(new_child_seq_id)
  83. child.append_token_id(child_sample.output_token,
  84. child_sample.logprobs)
  85. child_seqs.append((child, parent))
  86. # Continue the parent sequence for the last child sample.
  87. # We reuse the parent sequence here to reduce redundant memory
  88. # copies, especially when using non-beam search sampling methods.
  89. last_child_sample = child_samples[-1]
  90. parent.append_token_id(last_child_sample.output_token,
  91. last_child_sample.logprobs)
  92. child_seqs.append((parent, parent))
  93. for seq, _ in child_seqs:
  94. if seq_group.sampling_params.detokenize:
  95. new_char_count = self.detokenizer.decode_sequence_inplace(
  96. seq, seq_group.sampling_params)
  97. else:
  98. new_char_count = 0
  99. self.stop_checker.maybe_stop_sequence(seq, new_char_count,
  100. seq_group.sampling_params)
  101. # Non-beam search case
  102. if not seq_group.sampling_params.use_beam_search:
  103. # For newly created child sequences, add them to the sequence group
  104. # and fork them in block manager if they are not finished.
  105. for seq, parent in child_seqs:
  106. if seq is not parent:
  107. seq_group.add(seq)
  108. if not seq.is_finished():
  109. self.scheduler.fork_seq(parent, seq)
  110. # Free the finished and selected parent sequences' memory in block
  111. # manager. Keep them in the sequence group as candidate output.
  112. # NOTE: we need to fork the new sequences before freeing the
  113. # old sequences.
  114. for seq, parent in child_seqs:
  115. if seq is parent and seq.is_finished():
  116. self.scheduler.free_seq(seq)
  117. return
  118. # Beam search case
  119. # Select the child sequences to keep in the sequence group.
  120. selected_child_seqs = []
  121. unselected_child_seqs = []
  122. beam_width = seq_group.sampling_params.best_of
  123. length_penalty = seq_group.sampling_params.length_penalty
  124. # Select the newly finished sequences with the highest scores
  125. # to replace existing finished sequences.
  126. # Tuple of (seq, parent, is_new)
  127. existing_finished_seqs = [(seq, None, False)
  128. for seq in existing_finished_seqs]
  129. new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
  130. if seq.is_finished()]
  131. all_finished_seqs = existing_finished_seqs + new_finished_seqs
  132. # Sort the finished sequences by their scores.
  133. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  134. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  135. reverse=True)
  136. for seq, parent, is_new in all_finished_seqs[:beam_width]:
  137. if is_new:
  138. # A newly generated child sequence finishes and has a high
  139. # score, so we will add it into the sequence group.
  140. selected_child_seqs.append((seq, parent))
  141. for seq, parent, is_new in all_finished_seqs[beam_width:]:
  142. if is_new:
  143. # A newly generated child sequence finishes but has a low
  144. # score, so we will not add it into the sequence group.
  145. # Additionally, if this sequence is a continuation of a
  146. # parent sequence, we will need remove the parent sequence
  147. # from the sequence group.
  148. unselected_child_seqs.append((seq, parent))
  149. else:
  150. # An existing finished sequence has a low score, so we will
  151. # remove it from the sequence group.
  152. seq_group.remove(seq.seq_id)
  153. # select the top beam_width sequences from the running
  154. # sequences for the next iteration to continue the beam
  155. # search.
  156. running_child_seqs = [(seq, parent) for seq, parent in child_seqs
  157. if not seq.is_finished()]
  158. # Sort the running sequences by their scores.
  159. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  160. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  161. reverse=True)
  162. # Check if we can stop the beam search.
  163. if len(running_child_seqs) == 0:
  164. # No running sequences, stop the beam search.
  165. stop_beam_search = True
  166. elif len(all_finished_seqs) < beam_width:
  167. # Not enough finished sequences, continue the beam search.
  168. stop_beam_search = False
  169. else:
  170. # Check the early stopping criteria
  171. best_running_seq = running_child_seqs[0][0]
  172. current_worst_seq = all_finished_seqs[beam_width - 1][0]
  173. stop_beam_search = self._check_beam_search_early_stopping(
  174. seq_group.sampling_params.early_stopping,
  175. seq_group.sampling_params, best_running_seq, current_worst_seq)
  176. if stop_beam_search:
  177. # Stop the beam search and remove all the running sequences from
  178. # the sequence group.
  179. unselected_child_seqs.extend(running_child_seqs)
  180. else:
  181. # Continue the beam search and select the top beam_width sequences
  182. # to continue the beam search.
  183. selected_child_seqs.extend(running_child_seqs[:beam_width])
  184. # The remaining running sequences will not be used in the next
  185. # iteration. Again, if these sequences are continuations of
  186. # parent sequences, we will need to remove the parent sequences
  187. # from the sequence group.
  188. unselected_child_seqs.extend(running_child_seqs[beam_width:])
  189. # For newly created child sequences, add them to the sequence group
  190. # and fork them in block manager if they are not finished.
  191. for seq, parent in selected_child_seqs:
  192. if seq is not parent:
  193. seq_group.add(seq)
  194. if not seq.is_finished():
  195. self.scheduler.fork_seq(parent, seq)
  196. # Free the finished and selected parent sequences' memory in block
  197. # manager. Keep them in the sequence group as candidate output.
  198. for seq, parent in selected_child_seqs:
  199. if seq is parent and seq.is_finished():
  200. self.scheduler.free_seq(seq)
  201. # Remove the unselected parent sequences from the sequence group and
  202. # free their memory in block manager.
  203. for seq, parent in unselected_child_seqs:
  204. if seq is parent:
  205. # Remove the parent sequence if it is not selected for next
  206. # iteration
  207. seq_group.remove(seq.seq_id)
  208. self.scheduler.free_seq(seq)
  209. def _check_beam_search_early_stopping(
  210. self,
  211. early_stopping: Union[bool, str],
  212. sampling_params: SamplingParams,
  213. best_running_seq: Sequence,
  214. current_worst_seq: Sequence,
  215. ) -> bool:
  216. assert sampling_params.use_beam_search
  217. length_penalty = sampling_params.length_penalty
  218. if early_stopping is True:
  219. return True
  220. current_worst_score = current_worst_seq.get_beam_search_score(
  221. length_penalty=length_penalty,
  222. eos_token_id=current_worst_seq.eos_token_id)
  223. if early_stopping is False:
  224. highest_attainable_score = best_running_seq.get_beam_search_score(
  225. length_penalty=length_penalty,
  226. eos_token_id=best_running_seq.eos_token_id)
  227. else:
  228. assert early_stopping == "never"
  229. if length_penalty > 0.0:
  230. # If length_penalty > 0.0, beam search will prefer longer
  231. # sequences. The highest attainable score calculation is
  232. # based on the longest possible sequence length in this case.
  233. max_possible_length = max(
  234. best_running_seq.get_prompt_len() +
  235. sampling_params.max_tokens,
  236. self.scheduler_config.max_model_len)
  237. highest_attainable_score = (
  238. best_running_seq.get_beam_search_score(
  239. length_penalty=length_penalty,
  240. eos_token_id=best_running_seq.eos_token_id,
  241. seq_len=max_possible_length))
  242. else:
  243. # Otherwise, beam search will prefer shorter sequences. The
  244. # highest attainable score calculation is based on the current
  245. # sequence length.
  246. highest_attainable_score = (
  247. best_running_seq.get_beam_search_score(
  248. length_penalty=length_penalty,
  249. eos_token_id=best_running_seq.eos_token_id))
  250. return current_worst_score >= highest_attainable_score