# batch_expansion.py
from itertools import chain, count
from typing import Iterator, List, Tuple

import torch

from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
                                       SequenceData, SequenceGroupMetadata,
                                       SequenceGroupState, get_all_seq_ids)
from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                              SpeculativeScorer,
                                              SpeculativeScores)
from aphrodite.spec_decode.util import (nvtx_range, sampler_output_to_torch,
                                        split_batch_by_proposal_len)
from aphrodite.task_handler.worker_base import WorkerBase
  13. SeqId = int
  14. TargetSeqId = int
  15. TokenId = int
  16. class BatchExpansionTop1Scorer(SpeculativeScorer):
  17. """Implements a speculative scorer that uses batch expansion to get
  18. probabilities of speculative tokens according to the scoring model.
  19. Batch expansion converts a list of sequences and multiple query positions
  20. to a new batch of sequences, each with a single query position. This allows
  21. for MQA-like scoring in speculative decoding without requiring an MQA
  22. kernel.
  23. It is strictly less efficient than MQA scoring.
  24. It only supports scoring the top1 proposal tokens of the proposer, instead
  25. of topk/tree.
  26. """
  27. def __init__(self, scorer_worker: WorkerBase, device: str,
  28. vocab_size: int):
  29. self._scorer_worker = scorer_worker
  30. self._device = device
  31. self._vocab_size = vocab_size
  32. @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
  33. def score_proposals(
  34. self,
  35. execute_model_req: ExecuteModelRequest,
  36. proposals: SpeculativeProposals,
  37. ) -> SpeculativeScores:
  38. """Score the proposed tokens via the scorer model.
  39. This converts each input sequence to a set of k+1 target sequences. The
  40. target sequences have the unique continuations to be scored and a
  41. unique sequence ID that is different from all input sequence ids.
  42. If a speculative sequence length would exceed the max model length, then
  43. no speculation is produced for that sequence.
  44. Args:
  45. execute_model_req: The execution request.
  46. proposals: The speculative proposals to score.
  47. Returns:
  48. SpeculativeScores: The scores of each speculative token, along with
  49. which sequences were ignored during scoring.
  50. """
  51. # TODO: perform this on GPU to remove blocking call.
  52. proposal_lens_list = proposals.proposal_lens.tolist()
  53. proposal_token_ids_list = proposals.proposal_token_ids.tolist()
  54. # Filter the list to ignore -1 proposals.
  55. proposal_token_ids_list_without_skips = [
  56. proposals for proposals in proposal_token_ids_list
  57. if -1 not in proposals
  58. ]
  59. (spec_indices, non_spec_indices, target_seq_group_metadata_list,
  60. num_scoring_tokens) = self._expand_batch(
  61. seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
  62. proposal_token_ids_list=proposal_token_ids_list_without_skips,
  63. proposal_lens_list=proposal_lens_list,
  64. )
  65. target_sampler_output = self._scorer_worker.execute_model(
  66. execute_model_req=execute_model_req.clone(
  67. seq_group_metadata_list=target_seq_group_metadata_list))
  68. assert len(target_sampler_output) == 1, "expected single-step output"
  69. target_sampler_output = target_sampler_output[0]
  70. all_tokens, all_probs, spec_logprobs = self._contract_batch(
  71. contracted_bs=len(execute_model_req.seq_group_metadata_list),
  72. target_sampler_output=target_sampler_output,
  73. proposals=proposals,
  74. num_scoring_tokens=num_scoring_tokens,
  75. non_spec_indices=non_spec_indices,
  76. spec_indices=spec_indices,
  77. k=execute_model_req.num_lookahead_slots,
  78. )
  79. return SpeculativeScores(
  80. probs=all_probs,
  81. token_ids=all_tokens,
  82. logprobs=spec_logprobs,
  83. hidden_states=target_sampler_output.hidden_states,
  84. )
  85. def _expand_batch(
  86. self,
  87. seq_group_metadata_list: List[SequenceGroupMetadata],
  88. proposal_token_ids_list: List[List[TokenId]],
  89. proposal_lens_list: List[int],
  90. ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
  91. """Given the input sequences and potentially multiple corresponding
  92. proposal tokens, create a new batch where each sequence has a single
  93. query token.
  94. """
  95. # Aphrodite currently only supports proposal lens equal to zero or the
  96. # batch proposal len. This adds some complexity (splitting the batch
  97. # into spec and non spec sequences) and should be removed in the
  98. # future. It can be done by supporting per-sequence proposal lens.
  99. spec_seqs, spec_indices = split_batch_by_proposal_len(
  100. seq_group_metadata_list,
  101. proposal_lens_list,
  102. select_proposal_len_zero=False)
  103. non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
  104. seq_group_metadata_list,
  105. proposal_lens_list,
  106. select_proposal_len_zero=True)
  107. target_seq_group_metadata_list = self._create_scoring_model_input(
  108. seq_group_metadata_list=spec_seqs,
  109. proposal_token_ids=proposal_token_ids_list,
  110. # NOTE: We determine the seq ids in the expanded batch using the
  111. # full seq_group_metadata_list, instead of only spec_seqs.
  112. target_seq_ids_iter=self._create_target_seq_id_iterator(
  113. seq_ids=get_all_seq_ids(seq_group_metadata_list)),
  114. )
  115. num_scoring_tokens = len(target_seq_group_metadata_list)
  116. target_seq_group_metadata_list.extend(non_spec_seqs)
  117. return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
  118. num_scoring_tokens)
  119. def _contract_batch(
  120. self, contracted_bs: int, target_sampler_output: SamplerOutput,
  121. proposals: SpeculativeProposals, num_scoring_tokens: int,
  122. non_spec_indices: List[int], spec_indices: List[int],
  123. k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  124. """Contract the expanded batch back into its original size.
  125. This maps the scores of speculative tokens back to their original
  126. sequences.
  127. contracted_bs is the original batch size, and the batch size that the
  128. target_sampler_output will be contracted to.
  129. """
  130. (target_token_ids, target_probs, target_logprobs,
  131. non_spec_target_token_ids, non_spec_target_probs,
  132. non_spec_target_logprobs) = self._split_scoring_output(
  133. target_sampler_output, num_scoring_tokens)
  134. # Map distinct sequences used to score each token
  135. # of shape [batch_size * k + 1] back to [batch_size, k + 1].
  136. expanded_batch_size, k = proposals.proposal_token_ids.shape
  137. # The number of tokens in the expanded batch used for speculation is
  138. # equal to the total expanded batch size minus the number of samples for
  139. # non-speculative sequences.
  140. non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
  141. spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
  142. target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
  143. target_probs = target_probs.reshape(*target_token_ids.shape,
  144. self._vocab_size)
  145. target_logprobs = target_logprobs.reshape(target_probs.shape)
  146. all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
  147. fill_value=-1)
  148. all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
  149. all_logprobs = target_logprobs.new_full(size=all_probs.shape,
  150. fill_value=-float("inf"))
  151. if non_spec_indices:
  152. all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
  153. all_probs[non_spec_indices, :1, :] = non_spec_target_probs
  154. all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
  155. if spec_indices:
  156. all_tokens[spec_indices] = target_token_ids
  157. all_probs[spec_indices] = target_probs
  158. all_logprobs[spec_indices] = target_logprobs
  159. return all_tokens, all_probs, all_logprobs
  160. def _create_scoring_model_input(
  161. self,
  162. seq_group_metadata_list: List[SequenceGroupMetadata],
  163. proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
  164. target_seq_ids_iter: Iterator[TargetSeqId],
  165. ) -> List[SequenceGroupMetadata]:
  166. """Given the original input sequences and proposed tokens from the draft
  167. model, create a list of target sequences that can be used for scoring.
  168. target_seq_ids_iter provides sequence ids for the expanded batch,
  169. fulfilling the requirement that no seq id in the expanded batch is equal
  170. to the seq id in the original batch.
  171. """
  172. if not seq_group_metadata_list:
  173. return []
  174. target_seq_group_metadata = list(
  175. chain.from_iterable(
  176. self._create_target_seq_group_metadata(
  177. seq_group_metadata,
  178. proposal_token_ids,
  179. i,
  180. target_seq_ids_iter,
  181. ) for i, seq_group_metadata in enumerate(
  182. seq_group_metadata_list)))
  183. return target_seq_group_metadata
  184. def _create_target_seq_group_metadata(
  185. self,
  186. input_seq_group_metadata: SequenceGroupMetadata,
  187. proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
  188. batch_index: int,
  189. target_seq_ids_iter: Iterator[TargetSeqId],
  190. ) -> List[SequenceGroupMetadata]:
  191. """Given an input sequence group metadata and a list of draft tokens,
  192. create a list of target SequenceGroupMetadata, one for each
  193. token id that needs to be scored.
  194. Naive speculative decoding requires K target model scores, one for each
  195. draft model token. However one can add a bonus token such that if each
  196. token is accepted, then a final token may be sampled from the model.
  197. This function creates K+1 target SequenceGroupMetadata to take
  198. advantage of the bonus token.
  199. """
  200. assert not input_seq_group_metadata.is_prompt, (
  201. "Speculating on "
  202. "prompts not yet supported")
  203. assert len(input_seq_group_metadata.seq_data) == 1, (
  204. "Beam search "
  205. "not supported in speculative decoding")
  206. input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
  207. token_ids_to_score = self._get_token_ids_to_score(
  208. proposal_token_ids[batch_index])
  209. target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
  210. for token_ids in token_ids_to_score:
  211. target_seq_group_metadata_list.append(
  212. self._create_single_target_seq_group_metadata(
  213. input_seq_group_metadata,
  214. input_seq_id,
  215. next(target_seq_ids_iter),
  216. token_ids,
  217. ))
  218. return target_seq_group_metadata_list
  219. def _create_single_target_seq_group_metadata(
  220. self,
  221. seq_group_metadata: SequenceGroupMetadata,
  222. seq_id: SeqId,
  223. target_seq_id: TargetSeqId,
  224. token_ids: List[TokenId],
  225. ) -> SequenceGroupMetadata:
  226. """Create a single target SequenceGroupMetadata.
  227. Args:
  228. seq_group_metadata: The metadata for the input sequence.
  229. seq_id: The input sequence ID.
  230. target_seq_id: The corresponding target sequence ID.
  231. token_ids: The list of token ids that are to be appended to the
  232. input sequence.
  233. """
  234. seq_data = seq_group_metadata.seq_data[seq_id]
  235. prompt_token_ids = seq_data.get_prompt_token_ids()
  236. new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
  237. new_seq_data_dict = {
  238. target_seq_id:
  239. SequenceData(
  240. prompt_token_ids=prompt_token_ids,
  241. output_token_ids=new_output_token_ids,
  242. ),
  243. }
  244. # This is a hack. Technically, spec decoding should compute
  245. # num_lookahead slots at one shot, but instead, it expands the batch
  246. # and evaluate one by one right now. context_len is seq_len - 1 because
  247. # the kv cache is filled by a previous batch in the batch expansion.
  248. for data in new_seq_data_dict.values():
  249. data.update_num_computed_tokens(data.get_len() - 1)
  250. if (seq_group_metadata.state is not None
  251. and seq_group_metadata.state.generator is not None):
  252. generator = torch.Generator(
  253. device=seq_group_metadata.state.generator.device)
  254. generator.set_state(seq_group_metadata.state.generator.get_state())
  255. state = SequenceGroupState(generator=generator)
  256. else:
  257. state = None
  258. return SequenceGroupMetadata(
  259. request_id=seq_group_metadata.request_id,
  260. is_prompt=seq_group_metadata.is_prompt,
  261. seq_data=new_seq_data_dict,
  262. sampling_params=seq_group_metadata.sampling_params,
  263. block_tables={
  264. target_seq_id: seq_group_metadata.block_tables[seq_id],
  265. },
  266. lora_request=None,
  267. token_chunk_size=1,
  268. state=state,
  269. )
  270. def _split_scoring_output(
  271. self, sampler_output: SamplerOutput, num_scoring_tokens: int
  272. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
  273. torch.Tensor, torch.Tensor]:
  274. """Split the target model output into speculative and non-speculative
  275. output.
  276. """
  277. # Aphrodite currently only supports proposal lens equal to zero or the
  278. # batch proposal len. This adds some complexity (splitting the batch
  279. # into spec and non spec sequences) and should be removed in the
  280. # future. It can be done by supporting per-sequence proposal lens.
  281. # First samples are from speculative scoring, latter samples are non-
  282. # speculative samples.
  283. split_sizes = [
  284. num_scoring_tokens,
  285. sampler_output.sampled_token_ids.numel() - num_scoring_tokens
  286. ]
  287. (spec_probs, non_spec_probs
  288. ) = sampler_output.sampled_token_probs.split(split_sizes)
  289. (spec_sampled_tokens, non_spec_sampled_tokens
  290. ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
  291. (
  292. spec_logprobs,
  293. non_spec_logprobs,
  294. ) = sampler_output.logprobs.split(split_sizes)
  295. # Convert scores to tensors.
  296. sampler_output.sampled_token_probs = spec_probs
  297. sampler_output.sampled_token_ids = spec_sampled_tokens
  298. sampler_output.logprobs = spec_logprobs
  299. (target_token_ids, target_probs,
  300. target_logprobs) = sampler_output_to_torch([sampler_output], True)
  301. # Convert non-speculative output tokens to tensors.
  302. sampler_output.sampled_token_probs = non_spec_probs
  303. sampler_output.sampled_token_ids = non_spec_sampled_tokens
  304. sampler_output.logprobs = non_spec_logprobs
  305. (non_spec_target_token_ids, non_spec_target_probs,
  306. non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
  307. True)
  308. return (target_token_ids, target_probs, target_logprobs,
  309. non_spec_target_token_ids, non_spec_target_probs,
  310. non_spec_target_logprobs)
  311. def _create_target_seq_id_iterator(
  312. self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
  313. """Create an iterator for creating target sequence ids.
  314. Target sequence ids are distinct from sequence ids because we create a
  315. distinct target sequence id for each proposal token to be scored.
  316. This implementation increments a counter starting at 1 + max of all
  317. provided input sequence ids.
  318. """
  319. return count(start=max(seq_ids) + 1)
  320. def _get_token_ids_to_score(
  321. self,
  322. full_spec_token_ids: List[TokenId] # shape: [k]
  323. ) -> List[List[TokenId]]:
  324. """Given an int tensor of proposal token ids, return a list of
  325. token ids that should be scored.
  326. Returns k+1 output lists. The additional one is used for generating the
  327. bonus token.
  328. Example:
  329. Input: [0, 1, 2, 3] (k=4)
  330. Output: (k+1 lists)
  331. []
  332. [0]
  333. [0, 1]
  334. [0, 1, 2]
  335. [0, 1, 2, 3]
  336. """
  337. empty_token_ids: List[TokenId] = []
  338. token_ids_to_score = [empty_token_ids]
  339. token_ids_to_score.extend([
  340. full_spec_token_ids[:i + 1]
  341. for i in range(len(full_spec_token_ids))
  342. ])
  343. return token_ids_to_score