single_step.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. from typing import Dict, List, Tuple, Union
  2. from aphrodite.common.config import SchedulerConfig
  3. from aphrodite.common.sampling_params import SamplingParams
  4. from aphrodite.common.sequence import (Sequence, SequenceGroup,
  5. SequenceGroupOutput, SequenceOutput,
  6. SequenceStatus)
  7. from aphrodite.common.utils import Counter
  8. from aphrodite.engine.output_processor.interfaces import (
  9. SequenceGroupOutputProcessor)
  10. from aphrodite.engine.output_processor.stop_checker import StopChecker
  11. from aphrodite.processing.scheduler import Scheduler
  12. from aphrodite.transformers_utils.detokenizer import Detokenizer
  13. def single_step_process_prompt_logprob(
  14. sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
  15. output: SequenceGroupOutput) -> None:
  16. """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
  17. for a given step.
  18. Do nothing if the output has no prompt logprobs.
  19. Account for the fact that transformers do not compute first-token logprobs.
  20. Args:
  21. sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
  22. seq_group: the output is associated with this :class:`SequenceGroup`
  23. output: the :class:`SequenceGroupOutput` for a single scheduler step
  24. """
  25. prompt_logprobs = output.prompt_logprobs
  26. # If this is the first (or only) "chunk" of the prefill, we need
  27. # to prepend None to the list of prompt logprobs. The reason for this
  28. # is that for N prompt tokens, the Sampler will generate N-1 total
  29. # prompt logprobs during prefill since the token at idx 0 will not
  30. # have a logprob associated with it.
  31. if prompt_logprobs is not None:
  32. if not seq_group.prompt_logprobs:
  33. prompt_logprobs = [None] + prompt_logprobs
  34. seq_group.prompt_logprobs = []
  35. assert hasattr(sg_output_proc, 'detokenizer')
  36. if (seq_group.sampling_params.detokenize
  37. and sg_output_proc.detokenizer):
  38. sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
  39. seq_group,
  40. prompt_logprobs,
  41. position_offset=len(seq_group.prompt_logprobs))
  42. seq_group.prompt_logprobs.extend(prompt_logprobs)
  43. class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
  44. """SequenceGroupOutputProcessor which handles "output processing" logic,
  45. which happens after the model returns generated token ids and before
  46. scheduling of the next batch. Output processing logic includes
  47. detokenization, and determining if a sequence is finished (e.g. via max len
  48. or eos token).
  49. The SingleStepOutputProcessor is specialized to the case where the model
  50. emits at most a single token per invocation, which precludes configurations
  51. such as speculative decoding or multi-step decoding. This enables beam
  52. search sampling, which requires forking/finishing/freeing sequences in a way
  53. that is currently difficult to schedule multiple steps ahead of time.
  54. """
  55. def __init__(self, scheduler_config: SchedulerConfig,
  56. detokenizer: Detokenizer, scheduler: List[Scheduler],
  57. seq_counter: Counter, stop_checker: StopChecker):
  58. self.scheduler_config = scheduler_config
  59. self.detokenizer = detokenizer
  60. self.scheduler = scheduler
  61. self.seq_counter = seq_counter
  62. self.stop_checker = stop_checker
  63. def process_outputs(self, sequence_group: SequenceGroup,
  64. outputs: List[SequenceGroupOutput],
  65. is_async: bool) -> None:
  66. """Append all new tokens to sequences in the sequence group. Fork any
  67. surviving beam candidates; free any unsurviving ones.
  68. Invokes detokenizer to detokenize new tokens, and also marks sequences
  69. as finished if they meet stop conditions.
  70. is_async - Indicates whether this postprocessor runs in
  71. parallel with the GPU forward pass and is processing
  72. tokens from the previous step. If this is true, then
  73. no tokens need to be appended since it is already done
  74. externally (before the next schedule() call)
  75. """
  76. assert (len(outputs) == 1
  77. ), f"{type(self)} does not support multiple outputs per step"
  78. return self._process_sequence_group_outputs(sequence_group, outputs[0],
  79. is_async)
  80. def process_prompt_logprob(self, seq_group: SequenceGroup,
  81. outputs: List[SequenceGroupOutput]) -> None:
  82. """Process prompt logprobs associated with one step of a single-step-
  83. scheduled computation.
  84. Args:
  85. seq_group: the output is associated with this :class:`SequenceGroup`
  86. output: the :class:`SequenceGroupOutput` for a single scheduler step
  87. """
  88. assert len(outputs) == 1, ("Single step should only has 1 output.")
  89. output = outputs[0]
  90. single_step_process_prompt_logprob(self, seq_group, output)
  91. def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
  92. outputs: SequenceGroupOutput,
  93. is_async: bool) -> None:
  94. sampling_params = seq_group.sampling_params
  95. if sampling_params.n == 1 and not sampling_params.use_beam_search:
  96. # only have one output sample
  97. sample = outputs.samples[0]
  98. # only have one sequence
  99. seq = seq_group.seqs[0]
  100. if not is_async:
  101. seq.append_token_id(sample.output_token, sample.logprobs)
  102. if sampling_params.detokenize and self.detokenizer:
  103. new_char_count = self.detokenizer.decode_sequence_inplace(
  104. seq, sampling_params)
  105. else:
  106. new_char_count = 0
  107. self.stop_checker.maybe_stop_sequence(
  108. seq,
  109. new_char_count,
  110. sampling_params,
  111. lora_req=seq_group.lora_request,
  112. )
  113. if seq.is_finished():
  114. for scheduler in self.scheduler:
  115. scheduler.free_seq(seq)
  116. return
  117. # TODO: Add support for async for beam search
  118. assert not is_async
  119. # Process samples
  120. samples = outputs.samples
  121. parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  122. existing_finished_seqs = seq_group.get_finished_seqs()
  123. parent_child_dict: Dict[int, List[SequenceOutput]] = {
  124. parent_seq.seq_id: []
  125. for parent_seq in parent_seqs
  126. }
  127. for sample in samples:
  128. # Guard against a KeyError which can occur if the request was
  129. # aborted while the output was generated
  130. if (child_list :=
  131. parent_child_dict.get(sample.parent_seq_id)) is not None:
  132. child_list.append(sample)
  133. # List of (child, parent)
  134. child_seqs: List[Tuple[Sequence, Sequence]] = []
  135. # Process the child samples for each parent sequence
  136. for parent in parent_seqs:
  137. child_samples: List[SequenceOutput] = parent_child_dict[
  138. parent.seq_id]
  139. if len(child_samples) == 0:
  140. # This parent sequence has no children samples. Remove
  141. # the parent sequence from the sequence group since it will
  142. # not be used in the future iterations.
  143. parent.status = SequenceStatus.FINISHED_ABORTED
  144. seq_group.remove(parent.seq_id)
  145. for scheduler in self.scheduler:
  146. scheduler.free_seq(parent)
  147. continue
  148. # Fork the parent sequence if there are multiple child samples.
  149. for child_sample in child_samples[:-1]:
  150. new_child_seq_id: int = next(self.seq_counter)
  151. child = parent.fork(new_child_seq_id)
  152. child.append_token_id(child_sample.output_token,
  153. child_sample.logprobs)
  154. child_seqs.append((child, parent))
  155. # Continue the parent sequence for the last child sample.
  156. # We reuse the parent sequence here to reduce redundant memory
  157. # copies, especially when using non-beam search sampling methods.
  158. last_child_sample = child_samples[-1]
  159. parent.append_token_id(last_child_sample.output_token,
  160. last_child_sample.logprobs)
  161. child_seqs.append((parent, parent))
  162. for seq, _ in child_seqs:
  163. if sampling_params.detokenize and self.detokenizer:
  164. new_char_count = self.detokenizer.decode_sequence_inplace(
  165. seq, sampling_params)
  166. else:
  167. new_char_count = 0
  168. self.stop_checker.maybe_stop_sequence(
  169. seq,
  170. new_char_count,
  171. sampling_params,
  172. lora_req=seq_group.lora_request,
  173. )
  174. # Non-beam search case
  175. if not sampling_params.use_beam_search:
  176. # For newly created child sequences, add them to the sequence group
  177. # and fork them in block manager if they are not finished.
  178. for seq, parent in child_seqs:
  179. if seq is not parent:
  180. seq_group.add(seq)
  181. if not seq.is_finished():
  182. for scheduler in self.scheduler:
  183. scheduler.fork_seq(parent, seq)
  184. # Free the finished and selected parent sequences' memory in block
  185. # manager. Keep them in the sequence group as candidate output.
  186. # NOTE: we need to fork the new sequences before freeing the
  187. # old sequences.
  188. for seq, parent in child_seqs:
  189. if seq is parent and seq.is_finished():
  190. for scheduler in self.scheduler:
  191. scheduler.free_seq(seq)
  192. return
  193. # Beam search case
  194. # Select the child sequences to keep in the sequence group.
  195. selected_child_seqs = []
  196. unselected_child_seqs = []
  197. beam_width = sampling_params.best_of
  198. length_penalty = sampling_params.length_penalty
  199. # Select the newly finished sequences with the highest scores
  200. # to replace existing finished sequences.
  201. # Tuple of (seq, parent, is_new)
  202. existing_finished_seqs = [(seq, None, False)
  203. for seq in existing_finished_seqs]
  204. new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
  205. if seq.is_finished()]
  206. all_finished_seqs = existing_finished_seqs + new_finished_seqs
  207. # Sort the finished sequences by their scores.
  208. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  209. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  210. reverse=True)
  211. for seq, parent, is_new in all_finished_seqs[:beam_width]:
  212. if is_new:
  213. # A newly generated child sequence finishes and has a high
  214. # score, so we will add it into the sequence group.
  215. selected_child_seqs.append((seq, parent))
  216. for seq, parent, is_new in all_finished_seqs[beam_width:]:
  217. if is_new:
  218. # A newly generated child sequence finishes but has a low
  219. # score, so we will not add it into the sequence group.
  220. # Additionally, if this sequence is a continuation of a
  221. # parent sequence, we will need remove the parent sequence
  222. # from the sequence group.
  223. unselected_child_seqs.append((seq, parent))
  224. else:
  225. # An existing finished sequence has a low score, so we will
  226. # remove it from the sequence group.
  227. seq_group.remove(seq.seq_id)
  228. # select the top beam_width sequences from the running
  229. # sequences for the next iteration to continue the beam
  230. # search.
  231. running_child_seqs = [(seq, parent) for seq, parent in child_seqs
  232. if not seq.is_finished()]
  233. # Sort the running sequences by their scores.
  234. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  235. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  236. reverse=True)
  237. # Check if we can stop the beam search.
  238. if len(running_child_seqs) == 0:
  239. # No running sequences, stop the beam search.
  240. stop_beam_search = True
  241. elif len(all_finished_seqs) < beam_width:
  242. # Not enough finished sequences, continue the beam search.
  243. stop_beam_search = False
  244. else:
  245. # Check the early stopping criteria
  246. best_running_seq = running_child_seqs[0][0]
  247. current_worst_seq = all_finished_seqs[beam_width - 1][0]
  248. stop_beam_search = self._check_beam_search_early_stopping(
  249. sampling_params.early_stopping, sampling_params,
  250. best_running_seq, current_worst_seq)
  251. if stop_beam_search:
  252. # Stop the beam search and remove all the running sequences from
  253. # the sequence group.
  254. unselected_child_seqs.extend(running_child_seqs)
  255. else:
  256. # Continue the beam search and select the top beam_width sequences
  257. # to continue the beam search.
  258. selected_child_seqs.extend(running_child_seqs[:beam_width])
  259. # The remaining running sequences will not be used in the next
  260. # iteration. Again, if these sequences are continuations of
  261. # parent sequences, we will need to remove the parent sequences
  262. # from the sequence group.
  263. unselected_child_seqs.extend(running_child_seqs[beam_width:])
  264. # For newly created child sequences, add them to the sequence group
  265. # and fork them in block manager if they are not finished.
  266. for seq, parent in selected_child_seqs:
  267. if seq is not parent:
  268. seq_group.add(seq)
  269. if not seq.is_finished():
  270. for scheduler in self.scheduler:
  271. scheduler.fork_seq(parent, seq)
  272. # Free the finished and selected parent sequences' memory in block
  273. # manager. Keep them in the sequence group as candidate output.
  274. for seq, parent in selected_child_seqs:
  275. if seq is parent and seq.is_finished():
  276. for scheduler in self.scheduler:
  277. scheduler.free_seq(seq)
  278. # Remove the unselected parent sequences from the sequence group and
  279. # free their memory in block manager.
  280. for seq, parent in unselected_child_seqs:
  281. if seq is parent:
  282. # Remove the parent sequence if it is not selected for next
  283. # iteration
  284. seq_group.remove(seq.seq_id)
  285. for scheduler in self.scheduler:
  286. scheduler.free_seq(seq)
  287. def _check_beam_search_early_stopping(
  288. self,
  289. early_stopping: Union[bool, str],
  290. sampling_params: SamplingParams,
  291. best_running_seq: Sequence,
  292. current_worst_seq: Sequence,
  293. ) -> bool:
  294. assert sampling_params.use_beam_search
  295. length_penalty = sampling_params.length_penalty
  296. if early_stopping is True:
  297. return True
  298. current_worst_score = current_worst_seq.get_beam_search_score(
  299. length_penalty=length_penalty,
  300. eos_token_id=current_worst_seq.eos_token_id)
  301. if early_stopping is False:
  302. highest_attainable_score = best_running_seq.get_beam_search_score(
  303. length_penalty=length_penalty,
  304. eos_token_id=best_running_seq.eos_token_id)
  305. else:
  306. assert early_stopping == "never"
  307. if length_penalty > 0.0:
  308. # If length_penalty > 0.0, beam search will prefer longer
  309. # sequences. The highest attainable score calculation is
  310. # based on the longest possible sequence length in this case.
  311. max_possible_length = max(
  312. best_running_seq.get_prompt_len() +
  313. sampling_params.max_tokens,
  314. self.scheduler_config.max_model_len)
  315. highest_attainable_score = (
  316. best_running_seq.get_beam_search_score(
  317. length_penalty=length_penalty,
  318. eos_token_id=best_running_seq.eos_token_id,
  319. seq_len=max_possible_length))
  320. else:
  321. # Otherwise, beam search will prefer shorter sequences. The
  322. # highest attainable score calculation is based on the current
  323. # sequence length.
  324. highest_attainable_score = (
  325. best_running_seq.get_beam_search_score(
  326. length_penalty=length_penalty,
  327. eos_token_id=best_running_seq.eos_token_id))
  328. return current_worst_score >= highest_attainable_score