batch_expansion.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. from typing import Iterator, List, Tuple, Optional, Dict
  2. from itertools import chain, count
  3. import torch
  4. from aphrodite.common.sequence import (
  5. SamplerOutput,
  6. SequenceGroupMetadata,
  7. SequenceData,
  8. )
  9. from aphrodite.task_handler.worker import Worker
  10. from aphrodite.spec_decode.util import (
  11. nvtx_range,
  12. sampler_output_to_torch,
  13. get_all_seq_ids,
  14. split_batch_by_proposal_len,
  15. )
  16. from aphrodite.spec_decode.interfaces import (
  17. SpeculativeScorer,
  18. SpeculativeProposals,
  19. SpeculativeScores,
  20. )
  21. SeqId = int
  22. TargetSeqId = int
  23. TokenId = int
  24. class BatchExpansionTop1Scorer(SpeculativeScorer):
  25. """Implements a speculative scorer that uses batch expansion to get
  26. probabilities of speculative tokens according to the scoring model.
  27. Batch expansion converts a list of sequences and multiple query positions
  28. to a new batch of sequences, each with a single query position. This allows
  29. for MQA-like scoring in speculative decoding without requiring an MQA
  30. kernel.
  31. It is strictly less efficient than MQA scoring.
  32. It only supports scoring the top1 proposal tokens of the proposer, instead
  33. of topk/tree.
  34. """
  35. def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
  36. self._scorer_worker = scorer_worker
  37. self._device = device
  38. self._vocab_size = vocab_size
  39. @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
  40. def score_proposals(
  41. self,
  42. seq_group_metadata_list: List[SequenceGroupMetadata],
  43. blocks_to_swap_in: Optional[Dict[int, int]],
  44. blocks_to_swap_out: Optional[Dict[int, int]],
  45. blocks_to_copy: Optional[Dict[int, List[int]]],
  46. k: int,
  47. proposals: SpeculativeProposals,
  48. ) -> SpeculativeScores:
  49. """Score the proposed tokens via the scorer model.
  50. This converts each input sequence to a set of k+1 target sequences. The
  51. target sequences have the unique continuations to be scored and a
  52. unique sequence ID that is different from all input sequence ids.
  53. If a speculative sequence length would exceed the max model length, then
  54. no speculation is produced for that sequence.
  55. Args:
  56. seq_group_metadata_list: The input sequence group metadata.
  57. blocks_to_swap_in: This is passed to the worker during scoring.
  58. blocks_to_swap_out: This is passed to the worker during scoring.
  59. blocks_to_copy: This is passed to the worker during scoring.
  60. k: The fixed proposal length.
  61. proposals: The speculative proposals to score.
  62. Returns:
  63. SpeculativeScores: The scores of each speculative token, along with
  64. which sequences were ignored during scoring.
  65. """
  66. # TODO: perform this on GPU to remove blocking call.
  67. proposal_lens_list = proposals.proposal_lens.tolist()
  68. proposal_token_ids_list = proposals.proposal_token_ids.tolist()
  69. (
  70. spec_indices,
  71. non_spec_indices,
  72. target_seq_group_metadata_list,
  73. num_scoring_tokens,
  74. ) = self._expand_batch(
  75. seq_group_metadata_list=seq_group_metadata_list,
  76. proposal_token_ids_list=proposal_token_ids_list,
  77. proposal_lens_list=proposal_lens_list,
  78. )
  79. target_sampler_output = self._scorer_worker.execute_model(
  80. seq_group_metadata_list=target_seq_group_metadata_list,
  81. blocks_to_swap_in=blocks_to_swap_in,
  82. blocks_to_swap_out=blocks_to_swap_out,
  83. blocks_to_copy=blocks_to_copy,
  84. return_python_output=False,
  85. )
  86. all_tokens, all_probs = self._contract_batch(
  87. original_bs=len(seq_group_metadata_list),
  88. target_sampler_output=target_sampler_output,
  89. proposals=proposals,
  90. num_scoring_tokens=num_scoring_tokens,
  91. non_spec_indices=non_spec_indices,
  92. spec_indices=spec_indices,
  93. k=k,
  94. )
  95. return SpeculativeScores(
  96. probs=all_probs,
  97. token_ids=all_tokens,
  98. )
  99. def _expand_batch(
  100. self,
  101. seq_group_metadata_list: List[SequenceGroupMetadata],
  102. proposal_token_ids_list: List[TokenId],
  103. proposal_lens_list: List[int],
  104. ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
  105. """Given the input sequences and potentially multiple corresponding
  106. proposal tokens, create a new batch where each sequence has a single
  107. query token.
  108. """
  109. # Aphrodite currently only supports proposal lens equal to zero or the
  110. # batch proposal len. This adds some complexity (splitting the batch
  111. # into spec and non spec sequences) and should be removed in the
  112. # future. It can be done by supporting per-sequence proposal lens.
  113. spec_seqs, spec_indices = split_batch_by_proposal_len(
  114. seq_group_metadata_list,
  115. proposal_lens_list,
  116. select_proposal_len_zero=False,
  117. )
  118. non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
  119. seq_group_metadata_list,
  120. proposal_lens_list,
  121. select_proposal_len_zero=True,
  122. )
  123. target_seq_group_metadata_list = self._create_scoring_model_input(
  124. spec_seqs, proposal_token_ids_list)
  125. num_scoring_tokens = len(target_seq_group_metadata_list)
  126. target_seq_group_metadata_list.extend(non_spec_seqs)
  127. return (
  128. spec_indices,
  129. non_spec_indices,
  130. target_seq_group_metadata_list,
  131. num_scoring_tokens,
  132. )
  133. def _contract_batch(
  134. self,
  135. original_bs: int,
  136. target_sampler_output: List[SamplerOutput],
  137. proposals: SpeculativeProposals,
  138. num_scoring_tokens: int,
  139. non_spec_indices: List[int],
  140. spec_indices: List[int],
  141. k: int,
  142. ) -> Tuple[torch.Tensor, torch.Tensor]:
  143. """Contract the expanded batch back into its original size.
  144. This maps the scores of speculative tokens back to their original
  145. sequences.
  146. """
  147. (
  148. target_token_ids,
  149. target_probs,
  150. non_spec_target_token_ids,
  151. non_spec_target_probs,
  152. ) = self._split_scoring_output(target_sampler_output,
  153. num_scoring_tokens)
  154. # Map distinct sequences used to score each token
  155. # of shape [batch_size * k + 1] back to [batch_size, k + 1].
  156. batch_size, k = proposals.proposal_token_ids.shape
  157. target_token_ids = target_token_ids.squeeze().reshape(
  158. batch_size, k + 1)
  159. target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
  160. self._vocab_size)
  161. all_tokens = torch.full(
  162. size=(original_bs, k + 1),
  163. fill_value=-1,
  164. device=self._device,
  165. dtype=torch.long,
  166. )
  167. all_probs = torch.zeros(
  168. original_bs,
  169. k + 1,
  170. self._vocab_size,
  171. device=self._device,
  172. dtype=torch.float32,
  173. )
  174. if non_spec_indices:
  175. all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
  176. all_probs[non_spec_indices, :1, :] = non_spec_target_probs
  177. if spec_indices:
  178. all_tokens[spec_indices] = target_token_ids
  179. all_probs[spec_indices] = target_probs
  180. return all_tokens, all_probs
  181. def _create_scoring_model_input(
  182. self,
  183. seq_group_metadata_list: List[SequenceGroupMetadata],
  184. proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
  185. ) -> List[SequenceGroupMetadata]:
  186. """Given the original input sequences and proposed tokens from the draft
  187. model, create a list of target sequences that can be used for scoring.
  188. """
  189. if not seq_group_metadata_list:
  190. return []
  191. target_seq_ids_iter = self._create_target_seq_id_iterator(
  192. get_all_seq_ids(seq_group_metadata_list))
  193. target_seq_group_metadata = list(
  194. chain.from_iterable(
  195. self._create_target_seq_group_metadata(
  196. seq_group_metadata,
  197. proposal_token_ids,
  198. i,
  199. target_seq_ids_iter,
  200. ) for i, seq_group_metadata in enumerate(
  201. seq_group_metadata_list)))
  202. return target_seq_group_metadata
  203. def _create_target_seq_group_metadata(
  204. self,
  205. input_seq_group_metadata: SequenceGroupMetadata,
  206. proposal_token_ids: List[TokenId], # shape: [batch_size, k]
  207. batch_index: int,
  208. target_seq_ids_iter: Iterator[TargetSeqId],
  209. ) -> List[SequenceGroupMetadata]:
  210. """Given an input sequence group metadata and a list of draft tokens,
  211. create a list of target SequenceGroupMetadata, one for each
  212. token id that needs to be scored.
  213. Naive speculative decoding requires K target model scores, one for each
  214. draft model token. However one can add a bonus token such that if each
  215. token is accepted, then a final token may be sampled from the model.
  216. This function creates K+1 target SequenceGroupMetadata to take
  217. advantage of the bonus token.
  218. """
  219. assert not input_seq_group_metadata.is_prompt, (
  220. "Speculating on "
  221. "prompts not yet supported")
  222. assert len(input_seq_group_metadata.seq_data) == 1, (
  223. "Beam search "
  224. "not supported in speculative decoding")
  225. input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
  226. token_ids_to_score = self._get_token_ids_to_score(
  227. proposal_token_ids[batch_index])
  228. target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
  229. for token_ids in token_ids_to_score:
  230. target_seq_group_metadata_list.append(
  231. self._create_single_target_seq_group_metadata(
  232. input_seq_group_metadata,
  233. input_seq_id,
  234. next(target_seq_ids_iter),
  235. token_ids,
  236. ))
  237. return target_seq_group_metadata_list
  238. def _create_single_target_seq_group_metadata(
  239. self,
  240. seq_group_metadata: SequenceGroupMetadata,
  241. seq_id: SeqId,
  242. target_seq_id: TargetSeqId,
  243. token_ids: List[TokenId],
  244. ) -> SequenceGroupMetadata:
  245. """Create a single target SequenceGroupMetadata.
  246. Args:
  247. seq_group_metadata: The metadata for the input sequence.
  248. seq_id: The input sequence ID.
  249. target_seq_id: The corresponding target sequence ID.
  250. token_ids: The list of token ids that are to be appended to the
  251. input sequence.
  252. """
  253. seq_data = seq_group_metadata.seq_data[seq_id]
  254. prompt_token_ids = seq_data.get_prompt_token_ids()
  255. new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
  256. return SequenceGroupMetadata(
  257. request_id=seq_group_metadata.request_id,
  258. is_prompt=seq_group_metadata.is_prompt,
  259. seq_data={
  260. target_seq_id:
  261. SequenceData(
  262. prompt_token_ids=prompt_token_ids,
  263. output_token_ids=new_output_token_ids,
  264. ),
  265. },
  266. sampling_params=seq_group_metadata.sampling_params,
  267. block_tables={
  268. target_seq_id: seq_group_metadata.block_tables[seq_id],
  269. },
  270. lora_request=None,
  271. persistent_data={},
  272. )
  273. def _split_scoring_output(
  274. self, sampler_output: SamplerOutput, num_scoring_tokens: int
  275. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  276. """Split the target model output into speculative and non-speculative
  277. output.
  278. """
  279. # Aphrodite currently only supports proposal lens equal to zero or the
  280. # batch proposal len. This adds some complexity (splitting the batch
  281. # into spec and non spec sequences) and should be removed in the
  282. # future. It can be done by supporting per-sequence proposal lens.
  283. # First samples are from speculative scoring, latter samples are non-
  284. # speculative samples.
  285. split_sizes = [
  286. num_scoring_tokens,
  287. sampler_output.sampled_token_ids.numel() - num_scoring_tokens,
  288. ]
  289. (spec_probs, non_spec_probs
  290. ) = sampler_output.sampled_token_probs.split(split_sizes)
  291. (
  292. spec_sampled_tokens,
  293. non_spec_sampled_tokens,
  294. ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
  295. # Convert scores to tensors.
  296. sampler_output.sampled_token_probs = spec_probs
  297. sampler_output.sampled_token_ids = spec_sampled_tokens
  298. target_token_ids, target_probs = sampler_output_to_torch(
  299. [sampler_output])
  300. # Convert non-speculative output tokens to tensors.
  301. sampler_output.sampled_token_probs = non_spec_probs
  302. sampler_output.sampled_token_ids = non_spec_sampled_tokens
  303. (
  304. non_spec_target_token_ids,
  305. non_spec_target_probs,
  306. ) = sampler_output_to_torch([sampler_output])
  307. return (
  308. target_token_ids,
  309. target_probs,
  310. non_spec_target_token_ids,
  311. non_spec_target_probs,
  312. )
  313. def _create_target_seq_id_iterator(
  314. self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
  315. """Create an iterator for creating target sequence ids.
  316. Target sequence ids are distinct from sequence ids because we create a
  317. distinct target sequence id for each proposal token to be scored.
  318. This implementation increments a counter starting at 1 + max of all
  319. provided input sequence ids.
  320. """
  321. return count(start=max(seq_ids) + 1)
  322. def _get_token_ids_to_score(
  323. self,
  324. full_spec_token_ids: List[TokenId], # shape: [k]
  325. ) -> List[List[TokenId]]:
  326. """Given an int tensor of proposal token ids, return a list of
  327. token ids that should be scored.
  328. Returns k+1 output lists. The additional one is used for generating the
  329. bonus token.
  330. Example:
  331. Input: [0, 1, 2, 3] (k=4)
  332. Output: (k+1 lists)
  333. []
  334. [0]
  335. [0, 1]
  336. [0, 1, 2]
  337. [0, 1, 2, 3]
  338. """
  339. empty_token_ids = []
  340. token_ids_to_score = [empty_token_ids]
  341. token_ids_to_score.extend([
  342. full_spec_token_ids[:i + 1]
  343. for i in range(len(full_spec_token_ids))
  344. ])
  345. return token_ids_to_score