# single_step.py (13 KB)
from typing import Iterable, Iterator, List, Tuple, Union

from aphrodite.common.config import SchedulerConfig
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                       SequenceGroupOutput, SequenceOutput,
                                       SequenceStatus)
from aphrodite.engine.output_processor.interfaces import \
    SequenceGroupOutputProcessor
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.processing.scheduler import Scheduler
from aphrodite.transformers_utils.detokenizer import Detokenizer
  12. class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
  13. """SequenceGroupOutputProcessor which handles "output processing" logic,
  14. which happens after the model returns generated token ids and before
  15. scheduling of the next batch. Output processing logic includes
  16. detokenization, and determining if a sequence is finished (e.g. via max len
  17. or eos token).
  18. The SingleStepOutputProcessor is specialized to the case where the model
  19. emits at most a single token per invocation, which precludes configurations
  20. such as speculative decoding or multi-step decoding. This enables beam
  21. search sampling, which requires forking/finishing/freeing sequences in a way
  22. that is currently difficult to schedule multiple steps ahead of time.
  23. """
  24. def __init__(
  25. self,
  26. scheduler_config: SchedulerConfig,
  27. detokenizer: Detokenizer,
  28. scheduler: Scheduler,
  29. seq_counter: Iterable[int],
  30. stop_checker: StopChecker,
  31. ):
  32. self.scheduler_config = scheduler_config
  33. self.detokenizer = detokenizer
  34. self.scheduler = scheduler
  35. self.seq_counter = seq_counter
  36. self.stop_checker = stop_checker
  37. def process_outputs(self, sequence_group: SequenceGroup,
  38. outputs: List[SequenceGroupOutput]) -> None:
  39. """Append all new tokens to sequences in the sequence group. Fork any
  40. surviving beam candidates; free any unsurviving ones.
  41. Invokes detokenizer to detokenize new tokens, and also marks sequences
  42. as finished if they meet stop conditions.
  43. """
  44. assert (len(outputs) == 1
  45. ), f"{type(self)} does not support multiple outputs per step"
  46. return self._process_sequence_group_outputs(sequence_group, outputs[0])
  47. def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
  48. outputs: SequenceGroupOutput) -> None:
  49. # Process prompt logprobs
  50. prompt_logprobs = outputs.prompt_logprobs
  51. if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
  52. self.detokenizer.decode_prompt_logprobs_inplace(
  53. seq_group, prompt_logprobs)
  54. seq_group.prompt_logprobs = prompt_logprobs
  55. # Process samples
  56. samples = outputs.samples
  57. parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  58. existing_finished_seqs = seq_group.get_finished_seqs()
  59. parent_child_dict = {
  60. parent_seq.seq_id: []
  61. for parent_seq in parent_seqs
  62. }
  63. for sample in samples:
  64. parent_child_dict[sample.parent_seq_id].append(sample)
  65. # List of (child, parent)
  66. child_seqs: List[Tuple[Sequence, Sequence]] = []
  67. # In tree parallel decoding, all sequences within a sequence group
  68. # are always inferred simultaneously, which results in the generation
  69. # of some extra tokens that need to be appended.
  70. root_seq = seq_group.find(seq_group.root_seq_id)
  71. for _ in range(len(seq_group.seqs_dict) - len(parent_seqs)):
  72. root_seq._append_tokens_to_blocks([0])
  73. # Process the child samples for each parent sequence
  74. for parent in parent_seqs:
  75. child_samples: List[SequenceOutput] = parent_child_dict[
  76. parent.seq_id]
  77. if len(child_samples) == 0:
  78. # This parent sequence has no children samples. Remove
  79. # the parent sequence from the sequence group since it will
  80. # not be used in the future iterations.
  81. parent.status = SequenceStatus.FINISHED_ABORTED
  82. seq_group.remove(parent.seq_id)
  83. self.scheduler.free_seq(parent)
  84. continue
  85. # Fork the parent sequence if there are multiple child samples.
  86. for child_sample in child_samples[:-1]:
  87. new_child_seq_id = next(self.seq_counter)
  88. child = parent.fork(new_child_seq_id)
  89. child.append_token_id(child_sample.output_token,
  90. child_sample.logprobs)
  91. child_seqs.append((child, parent))
  92. # Continue the parent sequence for the last child sample.
  93. # We reuse the parent sequence here to reduce redundant memory
  94. # copies, especially when using non-beam search sampling methods.
  95. last_child_sample = child_samples[-1]
  96. parent.append_token_id(last_child_sample.output_token,
  97. last_child_sample.logprobs)
  98. child_seqs.append((parent, parent))
  99. for seq, _ in child_seqs:
  100. if seq_group.sampling_params.detokenize:
  101. new_char_count = self.detokenizer.decode_sequence_inplace(
  102. seq, seq_group.sampling_params)
  103. else:
  104. new_char_count = 0
  105. self.stop_checker.maybe_stop_sequence(seq, new_char_count,
  106. seq_group.sampling_params)
  107. # Non-beam search case
  108. if not seq_group.sampling_params.use_beam_search:
  109. # For newly created child sequences, add them to the sequence group
  110. # and fork them in block manager if they are not finished.
  111. for seq, parent in child_seqs:
  112. if seq is not parent:
  113. seq_group.add(seq)
  114. if not seq.is_finished():
  115. self.scheduler.fork_seq(parent, seq)
  116. # Free the finished and selected parent sequences' memory in block
  117. # manager. Keep them in the sequence group as candidate output.
  118. # NOTE: we need to fork the new sequences before freeing the
  119. # old sequences.
  120. for seq, parent in child_seqs:
  121. if seq is parent and seq.is_finished():
  122. self.scheduler.free_seq(seq)
  123. return
  124. # Beam search case
  125. # Select the child sequences to keep in the sequence group.
  126. selected_child_seqs = []
  127. unselected_child_seqs = []
  128. beam_width = seq_group.sampling_params.best_of
  129. length_penalty = seq_group.sampling_params.length_penalty
  130. # Select the newly finished sequences with the highest scores
  131. # to replace existing finished sequences.
  132. # Tuple of (seq, parent, is_new)
  133. existing_finished_seqs = [(seq, None, False)
  134. for seq in existing_finished_seqs]
  135. new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
  136. if seq.is_finished()]
  137. all_finished_seqs = existing_finished_seqs + new_finished_seqs
  138. # Sort the finished sequences by their scores.
  139. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  140. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  141. reverse=True)
  142. for seq, parent, is_new in all_finished_seqs[:beam_width]:
  143. if is_new:
  144. # A newly generated child sequence finishes and has a high
  145. # score, so we will add it into the sequence group.
  146. selected_child_seqs.append((seq, parent))
  147. for seq, parent, is_new in all_finished_seqs[beam_width:]:
  148. if is_new:
  149. # A newly generated child sequence finishes but has a low
  150. # score, so we will not add it into the sequence group.
  151. # Additionally, if this sequence is a continuation of a
  152. # parent sequence, we will need remove the parent sequence
  153. # from the sequence group.
  154. unselected_child_seqs.append((seq, parent))
  155. else:
  156. # An existing finished sequence has a low score, so we will
  157. # remove it from the sequence group.
  158. seq_group.remove(seq.seq_id)
  159. # select the top beam_width sequences from the running
  160. # sequences for the next iteration to continue the beam
  161. # search.
  162. running_child_seqs = [(seq, parent) for seq, parent in child_seqs
  163. if not seq.is_finished()]
  164. # Sort the running sequences by their scores.
  165. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  166. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  167. reverse=True)
  168. # Check if we can stop the beam search.
  169. if len(running_child_seqs) == 0:
  170. # No running sequences, stop the beam search.
  171. stop_beam_search = True
  172. elif len(all_finished_seqs) < beam_width:
  173. # Not enough finished sequences, continue the beam search.
  174. stop_beam_search = False
  175. else:
  176. # Check the early stopping criteria
  177. best_running_seq = running_child_seqs[0][0]
  178. current_worst_seq = all_finished_seqs[beam_width - 1][0]
  179. stop_beam_search = self._check_beam_search_early_stopping(
  180. seq_group.sampling_params.early_stopping,
  181. seq_group.sampling_params, best_running_seq, current_worst_seq)
  182. if stop_beam_search:
  183. # Stop the beam search and remove all the running sequences from
  184. # the sequence group.
  185. unselected_child_seqs.extend(running_child_seqs)
  186. else:
  187. # Continue the beam search and select the top beam_width sequences
  188. # to continue the beam search.
  189. selected_child_seqs.extend(running_child_seqs[:beam_width])
  190. # The remaining running sequences will not be used in the next
  191. # iteration. Again, if these sequences are continuations of
  192. # parent sequences, we will need to remove the parent sequences
  193. # from the sequence group.
  194. unselected_child_seqs.extend(running_child_seqs[beam_width:])
  195. # For newly created child sequences, add them to the sequence group
  196. # and fork them in block manager if they are not finished.
  197. for seq, parent in selected_child_seqs:
  198. if seq is not parent:
  199. seq_group.add(seq)
  200. if not seq.is_finished():
  201. self.scheduler.fork_seq(parent, seq)
  202. # Free the finished and selected parent sequences' memory in block
  203. # manager. Keep them in the sequence group as candidate output.
  204. for seq, parent in selected_child_seqs:
  205. if seq is parent and seq.is_finished():
  206. self.scheduler.free_seq(seq)
  207. # Remove the unselected parent sequences from the sequence group and
  208. # free their memory in block manager.
  209. for seq, parent in unselected_child_seqs:
  210. if seq is parent:
  211. # Remove the parent sequence if it is not selected for next
  212. # iteration
  213. seq_group.remove(seq.seq_id)
  214. self.scheduler.free_seq(seq)
  215. def _check_beam_search_early_stopping(
  216. self,
  217. early_stopping: Union[bool, str],
  218. sampling_params: SamplingParams,
  219. best_running_seq: Sequence,
  220. current_worst_seq: Sequence,
  221. ) -> bool:
  222. assert sampling_params.use_beam_search
  223. length_penalty = sampling_params.length_penalty
  224. if early_stopping is True:
  225. return True
  226. current_worst_score = current_worst_seq.get_beam_search_score(
  227. length_penalty=length_penalty,
  228. eos_token_id=current_worst_seq.eos_token_id)
  229. if early_stopping is False:
  230. highest_attainable_score = best_running_seq.get_beam_search_score(
  231. length_penalty=length_penalty,
  232. eos_token_id=best_running_seq.eos_token_id)
  233. else:
  234. assert early_stopping == "never"
  235. if length_penalty > 0.0:
  236. # If length_penalty > 0.0, beam search will prefer longer
  237. # sequences. The highest attainable score calculation is
  238. # based on the longest possible sequence length in this case.
  239. max_possible_length = max(
  240. best_running_seq.get_prompt_len() +
  241. sampling_params.max_tokens,
  242. self.scheduler_config.max_model_len)
  243. highest_attainable_score = (
  244. best_running_seq.get_beam_search_score(
  245. length_penalty=length_penalty,
  246. eos_token_id=best_running_seq.eos_token_id,
  247. seq_len=max_possible_length))
  248. else:
  249. # Otherwise, beam search will prefer shorter sequences. The
  250. # highest attainable score calculation is based on the current
  251. # sequence length.
  252. highest_attainable_score = (
  253. best_running_seq.get_beam_search_score(
  254. length_penalty=length_penalty,
  255. eos_token_id=best_running_seq.eos_token_id))
  256. return current_worst_score >= highest_attainable_score