single_step.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. from typing import Dict, List, Tuple, Union
  2. from aphrodite.common.config import SchedulerConfig
  3. from aphrodite.common.sampling_params import SamplingParams
  4. from aphrodite.common.sequence import (Sequence, SequenceGroup,
  5. SequenceGroupOutput, SequenceOutput,
  6. SequenceStatus)
  7. from aphrodite.common.utils import Counter
  8. from aphrodite.engine.output_processor.interfaces import (
  9. SequenceGroupOutputProcessor)
  10. from aphrodite.engine.output_processor.stop_checker import StopChecker
  11. from aphrodite.processing.scheduler import Scheduler
  12. from aphrodite.transformers_utils.detokenizer import Detokenizer
  13. class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
  14. """SequenceGroupOutputProcessor which handles "output processing" logic,
  15. which happens after the model returns generated token ids and before
  16. scheduling of the next batch. Output processing logic includes
  17. detokenization, and determining if a sequence is finished (e.g. via max len
  18. or eos token).
  19. The SingleStepOutputProcessor is specialized to the case where the model
  20. emits at most a single token per invocation, which precludes configurations
  21. such as speculative decoding or multi-step decoding. This enables beam
  22. search sampling, which requires forking/finishing/freeing sequences in a way
  23. that is currently difficult to schedule multiple steps ahead of time.
  24. """
  25. def __init__(self, scheduler_config: SchedulerConfig,
  26. detokenizer: Detokenizer, scheduler: List[Scheduler],
  27. seq_counter: Counter, stop_checker: StopChecker):
  28. self.scheduler_config = scheduler_config
  29. self.detokenizer = detokenizer
  30. self.scheduler = scheduler
  31. self.seq_counter = seq_counter
  32. self.stop_checker = stop_checker
  33. def process_outputs(self, sequence_group: SequenceGroup,
  34. outputs: List[SequenceGroupOutput],
  35. is_async: bool) -> None:
  36. """Append all new tokens to sequences in the sequence group. Fork any
  37. surviving beam candidates; free any unsurviving ones.
  38. Invokes detokenizer to detokenize new tokens, and also marks sequences
  39. as finished if they meet stop conditions.
  40. is_async - Indicates whether this postprocessor runs in
  41. parallel with the GPU forward pass and is processing
  42. tokens from the previous step. If this is true, then
  43. no tokens need to be appended since it is already done
  44. externally (before the next schedule() call)
  45. """
  46. assert (len(outputs) == 1
  47. ), f"{type(self)} does not support multiple outputs per step"
  48. return self._process_sequence_group_outputs(sequence_group, outputs[0],
  49. is_async)
  50. def process_prompt_logprob(self, seq_group: SequenceGroup,
  51. outputs: List[SequenceGroupOutput]) -> None:
  52. assert len(outputs) == 1, ("Single step should only has 1 output.")
  53. output = outputs[0]
  54. prompt_logprobs = output.prompt_logprobs
  55. # If this is the first (or only) "chunk" of the prefill, we need
  56. # to prepend None to the list of prompt logprobs. The reason for this
  57. # is that for N prompt tokens, the Sampler will generate N-1 total
  58. # prompt logprobs during prefill since the token at idx 0 will not
  59. # have a logprob associated with it.
  60. if prompt_logprobs is not None:
  61. if not seq_group.prompt_logprobs:
  62. prompt_logprobs = [None] + prompt_logprobs
  63. seq_group.prompt_logprobs = []
  64. if seq_group.sampling_params.detokenize and self.detokenizer:
  65. self.detokenizer.decode_prompt_logprobs_inplace(
  66. seq_group,
  67. prompt_logprobs,
  68. position_offset=len(seq_group.prompt_logprobs))
  69. seq_group.prompt_logprobs.extend(prompt_logprobs)
  70. def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
  71. outputs: SequenceGroupOutput,
  72. is_async: bool) -> None:
  73. sampling_params = seq_group.sampling_params
  74. if sampling_params.n == 1 and not sampling_params.use_beam_search:
  75. # only have one output sample
  76. sample = outputs.samples[0]
  77. # only have one sequence
  78. seq = seq_group.seqs[0]
  79. if not is_async:
  80. seq.append_token_id(sample.output_token, sample.logprobs)
  81. if sampling_params.detokenize and self.detokenizer:
  82. new_char_count = self.detokenizer.decode_sequence_inplace(
  83. seq, sampling_params)
  84. else:
  85. new_char_count = 0
  86. self.stop_checker.maybe_stop_sequence(
  87. seq,
  88. new_char_count,
  89. sampling_params,
  90. lora_req=seq_group.lora_request,
  91. )
  92. if seq.is_finished():
  93. for scheduler in self.scheduler:
  94. scheduler.free_seq(seq)
  95. return
  96. # TODO: Add support for async for beam search
  97. assert not is_async
  98. # Process samples
  99. samples = outputs.samples
  100. parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
  101. existing_finished_seqs = seq_group.get_finished_seqs()
  102. parent_child_dict: Dict[int, List[SequenceOutput]] = {
  103. parent_seq.seq_id: []
  104. for parent_seq in parent_seqs
  105. }
  106. for sample in samples:
  107. # Guard against a KeyError which can occur if the request was
  108. # aborted while the output was generated
  109. if (child_list :=
  110. parent_child_dict.get(sample.parent_seq_id)) is not None:
  111. child_list.append(sample)
  112. # List of (child, parent)
  113. child_seqs: List[Tuple[Sequence, Sequence]] = []
  114. # Process the child samples for each parent sequence
  115. for parent in parent_seqs:
  116. child_samples: List[SequenceOutput] = parent_child_dict[
  117. parent.seq_id]
  118. if len(child_samples) == 0:
  119. # This parent sequence has no children samples. Remove
  120. # the parent sequence from the sequence group since it will
  121. # not be used in the future iterations.
  122. parent.status = SequenceStatus.FINISHED_ABORTED
  123. seq_group.remove(parent.seq_id)
  124. for scheduler in self.scheduler:
  125. scheduler.free_seq(parent)
  126. continue
  127. # Fork the parent sequence if there are multiple child samples.
  128. for child_sample in child_samples[:-1]:
  129. new_child_seq_id: int = next(self.seq_counter)
  130. child = parent.fork(new_child_seq_id)
  131. child.append_token_id(child_sample.output_token,
  132. child_sample.logprobs)
  133. child_seqs.append((child, parent))
  134. # Continue the parent sequence for the last child sample.
  135. # We reuse the parent sequence here to reduce redundant memory
  136. # copies, especially when using non-beam search sampling methods.
  137. last_child_sample = child_samples[-1]
  138. parent.append_token_id(last_child_sample.output_token,
  139. last_child_sample.logprobs)
  140. child_seqs.append((parent, parent))
  141. for seq, _ in child_seqs:
  142. if sampling_params.detokenize and self.detokenizer:
  143. new_char_count = self.detokenizer.decode_sequence_inplace(
  144. seq, sampling_params)
  145. else:
  146. new_char_count = 0
  147. self.stop_checker.maybe_stop_sequence(
  148. seq,
  149. new_char_count,
  150. sampling_params,
  151. lora_req=seq_group.lora_request,
  152. )
  153. # Non-beam search case
  154. if not sampling_params.use_beam_search:
  155. # For newly created child sequences, add them to the sequence group
  156. # and fork them in block manager if they are not finished.
  157. for seq, parent in child_seqs:
  158. if seq is not parent:
  159. seq_group.add(seq)
  160. if not seq.is_finished():
  161. for scheduler in self.scheduler:
  162. scheduler.fork_seq(parent, seq)
  163. # Free the finished and selected parent sequences' memory in block
  164. # manager. Keep them in the sequence group as candidate output.
  165. # NOTE: we need to fork the new sequences before freeing the
  166. # old sequences.
  167. for seq, parent in child_seqs:
  168. if seq is parent and seq.is_finished():
  169. for scheduler in self.scheduler:
  170. scheduler.free_seq(seq)
  171. return
  172. # Beam search case
  173. # Select the child sequences to keep in the sequence group.
  174. selected_child_seqs = []
  175. unselected_child_seqs = []
  176. beam_width = sampling_params.best_of
  177. length_penalty = sampling_params.length_penalty
  178. # Select the newly finished sequences with the highest scores
  179. # to replace existing finished sequences.
  180. # Tuple of (seq, parent, is_new)
  181. existing_finished_seqs = [(seq, None, False)
  182. for seq in existing_finished_seqs]
  183. new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
  184. if seq.is_finished()]
  185. all_finished_seqs = existing_finished_seqs + new_finished_seqs
  186. # Sort the finished sequences by their scores.
  187. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  188. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  189. reverse=True)
  190. for seq, parent, is_new in all_finished_seqs[:beam_width]:
  191. if is_new:
  192. # A newly generated child sequence finishes and has a high
  193. # score, so we will add it into the sequence group.
  194. selected_child_seqs.append((seq, parent))
  195. for seq, parent, is_new in all_finished_seqs[beam_width:]:
  196. if is_new:
  197. # A newly generated child sequence finishes but has a low
  198. # score, so we will not add it into the sequence group.
  199. # Additionally, if this sequence is a continuation of a
  200. # parent sequence, we will need remove the parent sequence
  201. # from the sequence group.
  202. unselected_child_seqs.append((seq, parent))
  203. else:
  204. # An existing finished sequence has a low score, so we will
  205. # remove it from the sequence group.
  206. seq_group.remove(seq.seq_id)
  207. # select the top beam_width sequences from the running
  208. # sequences for the next iteration to continue the beam
  209. # search.
  210. running_child_seqs = [(seq, parent) for seq, parent in child_seqs
  211. if not seq.is_finished()]
  212. # Sort the running sequences by their scores.
  213. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
  214. length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
  215. reverse=True)
  216. # Check if we can stop the beam search.
  217. if len(running_child_seqs) == 0:
  218. # No running sequences, stop the beam search.
  219. stop_beam_search = True
  220. elif len(all_finished_seqs) < beam_width:
  221. # Not enough finished sequences, continue the beam search.
  222. stop_beam_search = False
  223. else:
  224. # Check the early stopping criteria
  225. best_running_seq = running_child_seqs[0][0]
  226. current_worst_seq = all_finished_seqs[beam_width - 1][0]
  227. stop_beam_search = self._check_beam_search_early_stopping(
  228. sampling_params.early_stopping, sampling_params,
  229. best_running_seq, current_worst_seq)
  230. if stop_beam_search:
  231. # Stop the beam search and remove all the running sequences from
  232. # the sequence group.
  233. unselected_child_seqs.extend(running_child_seqs)
  234. else:
  235. # Continue the beam search and select the top beam_width sequences
  236. # to continue the beam search.
  237. selected_child_seqs.extend(running_child_seqs[:beam_width])
  238. # The remaining running sequences will not be used in the next
  239. # iteration. Again, if these sequences are continuations of
  240. # parent sequences, we will need to remove the parent sequences
  241. # from the sequence group.
  242. unselected_child_seqs.extend(running_child_seqs[beam_width:])
  243. # For newly created child sequences, add them to the sequence group
  244. # and fork them in block manager if they are not finished.
  245. for seq, parent in selected_child_seqs:
  246. if seq is not parent:
  247. seq_group.add(seq)
  248. if not seq.is_finished():
  249. for scheduler in self.scheduler:
  250. scheduler.fork_seq(parent, seq)
  251. # Free the finished and selected parent sequences' memory in block
  252. # manager. Keep them in the sequence group as candidate output.
  253. for seq, parent in selected_child_seqs:
  254. if seq is parent and seq.is_finished():
  255. for scheduler in self.scheduler:
  256. scheduler.free_seq(seq)
  257. # Remove the unselected parent sequences from the sequence group and
  258. # free their memory in block manager.
  259. for seq, parent in unselected_child_seqs:
  260. if seq is parent:
  261. # Remove the parent sequence if it is not selected for next
  262. # iteration
  263. seq_group.remove(seq.seq_id)
  264. for scheduler in self.scheduler:
  265. scheduler.free_seq(seq)
  266. def _check_beam_search_early_stopping(
  267. self,
  268. early_stopping: Union[bool, str],
  269. sampling_params: SamplingParams,
  270. best_running_seq: Sequence,
  271. current_worst_seq: Sequence,
  272. ) -> bool:
  273. assert sampling_params.use_beam_search
  274. length_penalty = sampling_params.length_penalty
  275. if early_stopping is True:
  276. return True
  277. current_worst_score = current_worst_seq.get_beam_search_score(
  278. length_penalty=length_penalty,
  279. eos_token_id=current_worst_seq.eos_token_id)
  280. if early_stopping is False:
  281. highest_attainable_score = best_running_seq.get_beam_search_score(
  282. length_penalty=length_penalty,
  283. eos_token_id=best_running_seq.eos_token_id)
  284. else:
  285. assert early_stopping == "never"
  286. if length_penalty > 0.0:
  287. # If length_penalty > 0.0, beam search will prefer longer
  288. # sequences. The highest attainable score calculation is
  289. # based on the longest possible sequence length in this case.
  290. max_possible_length = max(
  291. best_running_seq.get_prompt_len() +
  292. sampling_params.max_tokens,
  293. self.scheduler_config.max_model_len)
  294. highest_attainable_score = (
  295. best_running_seq.get_beam_search_score(
  296. length_penalty=length_penalty,
  297. eos_token_id=best_running_seq.eos_token_id,
  298. seq_len=max_possible_length))
  299. else:
  300. # Otherwise, beam search will prefer shorter sequences. The
  301. # highest attainable score calculation is based on the current
  302. # sequence length.
  303. highest_attainable_score = (
  304. best_running_seq.get_beam_search_score(
  305. length_penalty=length_penalty,
  306. eos_token_id=best_running_seq.eos_token_id))
  307. return current_worst_score >= highest_attainable_score