single_step.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
from typing import Dict, List, Optional, Tuple, Union

from aphrodite.common.config import SchedulerConfig
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                       SequenceGroupOutput, SequenceOutput,
                                       SequenceStatus)
from aphrodite.common.utils import Counter
from aphrodite.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.processing.scheduler import Scheduler
from aphrodite.transformers_utils.detokenizer import Detokenizer


  13. class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
  14. """SequenceGroupOutputProcessor which handles "output processing" logic,
  15. which happens after the model returns generated token ids and before
  16. scheduling of the next batch. Output processing logic includes
  17. detokenization, and determining if a sequence is finished (e.g. via max len
  18. or eos token).
  19. The SingleStepOutputProcessor is specialized to the case where the model
  20. emits at most a single token per invocation, which precludes configurations
  21. such as speculative decoding or multi-step decoding. This enables beam
  22. search sampling, which requires forking/finishing/freeing sequences in a way
  23. that is currently difficult to schedule multiple steps ahead of time.
  24. """
  25. def __init__(
  26. self,
  27. scheduler_config: SchedulerConfig,
  28. detokenizer: Detokenizer,
  29. scheduler: List[Scheduler],
  30. seq_counter: Counter,
  31. stop_checker: StopChecker,
  32. ):
  33. self.scheduler_config = scheduler_config
  34. self.detokenizer = detokenizer
  35. self.scheduler = scheduler
  36. self.seq_counter = seq_counter
  37. self.stop_checker = stop_checker
  38. def process_outputs(self, sequence_group: SequenceGroup,
  39. outputs: List[SequenceGroupOutput]) -> None:
  40. """Append all new tokens to sequences in the sequence group. Fork any
  41. surviving beam candidates; free any unsurviving ones.
  42. Invokes detokenizer to detokenize new tokens, and also marks sequences
  43. as finished if they meet stop conditions.
  44. """
  45. assert (len(outputs) == 1
  46. ), f"{type(self)} does not support multiple outputs per step"
  47. return self._process_sequence_group_outputs(sequence_group, outputs[0])
  48. def process_prompt_logprob(self, seq_group: SequenceGroup,
  49. outputs: List[SequenceGroupOutput]) -> None:
  50. assert len(outputs) == 1, ("Single step should only has 1 output.")
  51. output = outputs[0]
  52. prompt_logprobs = output.prompt_logprobs
  53. # If this is the first (or only) "chunk" of the prefill, we need
  54. # to prepend None to the list of prompt logprobs. The reason for this
  55. # is that for N prompt tokens, the Sampler will generate N-1 total
  56. # prompt logprobs during prefill since the token at idx 0 will not
  57. # have a logprob associated with it.
  58. if prompt_logprobs is not None:
  59. if not seq_group.prompt_logprobs:
  60. prompt_logprobs = [None] + prompt_logprobs
  61. seq_group.prompt_logprobs = []
  62. if seq_group.sampling_params.detokenize and self.detokenizer:
  63. self.detokenizer.decode_prompt_logprobs_inplace(
  64. seq_group,
  65. prompt_logprobs,
  66. position_offset=len(seq_group.prompt_logprobs))
  67. seq_group.prompt_logprobs.extend(prompt_logprobs)
  68. def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
  69. outputs: SequenceGroupOutput) -> None:
  70. sampling_params = seq_group.sampling_params
  71. if sampling_params.n == 1 and not sampling_params.use_beam_search:
  72. # only have one output sample
  73. sample = outputs.samples[0]
  74. # only have one sequence
  75. seq = seq_group.seqs[0]
  76. seq.append_token_id(sample.output_token, sample.logprobs)
  77. if sampling_params.detokenize and self.detokenizer:
  78. new_char_count = self.detokenizer.decode_sequence_inplace(
  79. seq, sampling_params)
  80. else:
  81. new_char_count = 0
  82. self.stop_checker.maybe_stop_sequence(
  83. seq,
  84. new_char_count,
  85. sampling_params,
  86. lora_req=seq_group.lora_request,
  87. )
  88. if seq.is_finished():
  89. for scheduler in self.scheduler:
  90. scheduler.free_seq(seq)
  91. return
  92. # Process samples
  93. samples = outputs.samples
  94. parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  95. existing_finished_seqs = seq_group.get_finished_seqs()
  96. parent_child_dict: Dict[int, List[SequenceOutput]] = {
  97. parent_seq.seq_id: []
  98. for parent_seq in parent_seqs
  99. }
  100. for sample in samples:
  101. # Guard against a KeyError which can occur if the request was
  102. # aborted while the output was generated
  103. if (child_list :=
  104. parent_child_dict.get(sample.parent_seq_id)) is not None:
  105. child_list.append(sample)
  106. # List of (child, parent)
  107. child_seqs: List[Tuple[Sequence, Sequence]] = []
  108. # Process the child samples for each parent sequence
  109. for parent in parent_seqs:
  110. child_samples: List[SequenceOutput] = parent_child_dict[
  111. parent.seq_id]
  112. if len(child_samples) == 0:
  113. # This parent sequence has no children samples. Remove
  114. # the parent sequence from the sequence group since it will
  115. # not be used in the future iterations.
  116. parent.status = SequenceStatus.FINISHED_ABORTED
  117. seq_group.remove(parent.seq_id)
  118. for scheduler in self.scheduler:
  119. scheduler.free_seq(parent)
  120. continue
  121. # Fork the parent sequence if there are multiple child samples.
  122. for child_sample in child_samples[:-1]:
  123. new_child_seq_id: int = next(self.seq_counter)
  124. child = parent.fork(new_child_seq_id)
  125. child.append_token_id(child_sample.output_token,
  126. child_sample.logprobs)
  127. child_seqs.append((child, parent))
  128. # Continue the parent sequence for the last child sample.
  129. # We reuse the parent sequence here to reduce redundant memory
  130. # copies, especially when using non-beam search sampling methods.
  131. last_child_sample = child_samples[-1]
  132. parent.append_token_id(last_child_sample.output_token,
  133. last_child_sample.logprobs)
  134. child_seqs.append((parent, parent))
  135. for seq, _ in child_seqs:
  136. if sampling_params.detokenize and self.detokenizer:
  137. new_char_count = self.detokenizer.decode_sequence_inplace(
  138. seq, sampling_params)
  139. else:
  140. new_char_count = 0
  141. self.stop_checker.maybe_stop_sequence(
  142. seq,
  143. new_char_count,
  144. sampling_params,
  145. lora_req=seq_group.lora_request,
  146. )
  147. # Non-beam search case
  148. if not sampling_params.use_beam_search:
  149. # For newly created child sequences, add them to the sequence group
  150. # and fork them in block manager if they are not finished.
  151. for seq, parent in child_seqs:
  152. if seq is not parent:
  153. seq_group.add(seq)
  154. if not seq.is_finished():
  155. for scheduler in self.scheduler:
  156. scheduler.fork_seq(parent, seq)
  157. # Free the finished and selected parent sequences' memory in block
  158. # manager. Keep them in the sequence group as candidate output.
  159. # NOTE: we need to fork the new sequences before freeing the
  160. # old sequences.
  161. for seq, parent in child_seqs:
  162. if seq is parent and seq.is_finished():
  163. for scheduler in self.scheduler:
  164. scheduler.free_seq(seq)
  165. return
  166. # Beam search case
  167. # Select the child sequences to keep in the sequence group.
  168. selected_child_seqs = []
  169. unselected_child_seqs = []
  170. beam_width = sampling_params.best_of
  171. length_penalty = sampling_params.length_penalty
  172. # Select the newly finished sequences with the highest scores
  173. # to replace existing finished sequences.
  174. # Tuple of (seq, parent, is_new)
  175. existing_finished_seqs = [(seq, None, False)
  176. for seq in existing_finished_seqs]
  177. new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
  178. if seq.is_finished()]
  179. all_finished_seqs = existing_finished_seqs + new_finished_seqs
  180. # Sort the finished sequences by their scores.
  181. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  182. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  183. reverse=True)
  184. for seq, parent, is_new in all_finished_seqs[:beam_width]:
  185. if is_new:
  186. # A newly generated child sequence finishes and has a high
  187. # score, so we will add it into the sequence group.
  188. selected_child_seqs.append((seq, parent))
  189. for seq, parent, is_new in all_finished_seqs[beam_width:]:
  190. if is_new:
  191. # A newly generated child sequence finishes but has a low
  192. # score, so we will not add it into the sequence group.
  193. # Additionally, if this sequence is a continuation of a
  194. # parent sequence, we will need remove the parent sequence
  195. # from the sequence group.
  196. unselected_child_seqs.append((seq, parent))
  197. else:
  198. # An existing finished sequence has a low score, so we will
  199. # remove it from the sequence group.
  200. seq_group.remove(seq.seq_id)
  201. # select the top beam_width sequences from the running
  202. # sequences for the next iteration to continue the beam
  203. # search.
  204. running_child_seqs = [(seq, parent) for seq, parent in child_seqs
  205. if not seq.is_finished()]
  206. # Sort the running sequences by their scores.
  207. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  208. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  209. reverse=True)
  210. # Check if we can stop the beam search.
  211. if len(running_child_seqs) == 0:
  212. # No running sequences, stop the beam search.
  213. stop_beam_search = True
  214. elif len(all_finished_seqs) < beam_width:
  215. # Not enough finished sequences, continue the beam search.
  216. stop_beam_search = False
  217. else:
  218. # Check the early stopping criteria
  219. best_running_seq = running_child_seqs[0][0]
  220. current_worst_seq = all_finished_seqs[beam_width - 1][0]
  221. stop_beam_search = self._check_beam_search_early_stopping(
  222. sampling_params.early_stopping, sampling_params,
  223. best_running_seq, current_worst_seq)
  224. if stop_beam_search:
  225. # Stop the beam search and remove all the running sequences from
  226. # the sequence group.
  227. unselected_child_seqs.extend(running_child_seqs)
  228. else:
  229. # Continue the beam search and select the top beam_width sequences
  230. # to continue the beam search.
  231. selected_child_seqs.extend(running_child_seqs[:beam_width])
  232. # The remaining running sequences will not be used in the next
  233. # iteration. Again, if these sequences are continuations of
  234. # parent sequences, we will need to remove the parent sequences
  235. # from the sequence group.
  236. unselected_child_seqs.extend(running_child_seqs[beam_width:])
  237. # For newly created child sequences, add them to the sequence group
  238. # and fork them in block manager if they are not finished.
  239. for seq, parent in selected_child_seqs:
  240. if seq is not parent:
  241. seq_group.add(seq)
  242. if not seq.is_finished():
  243. for scheduler in self.scheduler:
  244. scheduler.fork_seq(parent, seq)
  245. # Free the finished and selected parent sequences' memory in block
  246. # manager. Keep them in the sequence group as candidate output.
  247. for seq, parent in selected_child_seqs:
  248. if seq is parent and seq.is_finished():
  249. for scheduler in self.scheduler:
  250. scheduler.free_seq(seq)
  251. # Remove the unselected parent sequences from the sequence group and
  252. # free their memory in block manager.
  253. for seq, parent in unselected_child_seqs:
  254. if seq is parent:
  255. # Remove the parent sequence if it is not selected for next
  256. # iteration
  257. seq_group.remove(seq.seq_id)
  258. for scheduler in self.scheduler:
  259. scheduler.free_seq(seq)
  260. def _check_beam_search_early_stopping(
  261. self,
  262. early_stopping: Union[bool, str],
  263. sampling_params: SamplingParams,
  264. best_running_seq: Sequence,
  265. current_worst_seq: Sequence,
  266. ) -> bool:
  267. assert sampling_params.use_beam_search
  268. length_penalty = sampling_params.length_penalty
  269. if early_stopping is True:
  270. return True
  271. current_worst_score = current_worst_seq.get_beam_search_score(
  272. length_penalty=length_penalty,
  273. eos_token_id=current_worst_seq.eos_token_id)
  274. if early_stopping is False:
  275. highest_attainable_score = best_running_seq.get_beam_search_score(
  276. length_penalty=length_penalty,
  277. eos_token_id=best_running_seq.eos_token_id)
  278. else:
  279. assert early_stopping == "never"
  280. if length_penalty > 0.0:
  281. # If length_penalty > 0.0, beam search will prefer longer
  282. # sequences. The highest attainable score calculation is
  283. # based on the longest possible sequence length in this case.
  284. max_possible_length = max(
  285. best_running_seq.get_prompt_len() +
  286. sampling_params.max_tokens,
  287. self.scheduler_config.max_model_len)
  288. highest_attainable_score = (
  289. best_running_seq.get_beam_search_score(
  290. length_penalty=length_penalty,
  291. eos_token_id=best_running_seq.eos_token_id,
  292. seq_len=max_possible_length))
  293. else:
  294. # Otherwise, beam search will prefer shorter sequences. The
  295. # highest attainable score calculation is based on the current
  296. # sequence length.
  297. highest_attainable_score = (
  298. best_running_seq.get_beam_search_score(
  299. length_penalty=length_penalty,
  300. eos_token_id=best_running_seq.eos_token_id))
  301. return current_worst_score >= highest_attainable_score