1
0

batch_expansion.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. from array import array
  2. from itertools import chain, count
  3. from typing import Iterator, List, Optional, Tuple
  4. import torch
  5. from aphrodite import SamplingParams
  6. from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
  7. ExecuteModelRequest, SamplerOutput,
  8. SequenceData, SequenceGroupMetadata,
  9. get_all_seq_ids)
  10. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  11. SpeculativeScorer,
  12. SpeculativeScores)
  13. from aphrodite.spec_decode.util import (nvtx_range, sampler_output_to_torch,
  14. split_batch_by_proposal_len)
  15. from aphrodite.task_handler.worker_base import WorkerBase
  # Readability aliases for plain ints; they only make signatures in this
  # module self-describing and carry no runtime checking.
  SeqId = int  # id of a sequence in the incoming (unexpanded) batch
  TargetSeqId = int  # id assigned to a sequence in the expanded batch
  TokenId = int  # vocabulary token id

  # Default-constructed SamplingParams. In this module it replaces the
  # request's own params for every non-bonus scoring position when the request
  # samples with a nonzero temperature.
  DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
  class BatchExpansionTop1Scorer(SpeculativeScorer):
      """Implements a speculative scorer that uses batch expansion to get
      probabilities of speculative tokens according to the scoring model.

      Batch expansion converts a list of sequences and multiple query positions
      to a new batch of sequences, each with a single query position. This
      allows for MQA-like scoring in speculative decoding without requiring an
      MQA kernel.

      It is strictly less efficient than MQA scoring.

      It only supports scoring the top1 proposal tokens of the proposer,
      instead of topk/tree.
      """

      def __init__(self, scorer_worker: WorkerBase, device: str,
                   vocab_size: int):
          # Worker whose execute_model() scores the expanded batch.
          self._scorer_worker = scorer_worker
          # Kept on the instance; not read by any method of this class.
          self._device = device
          # Needed to reshape flat per-row probs/logprobs back to
          # [num_seqs, k + 1, vocab_size] during contraction.
          self._vocab_size = vocab_size

      @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
      def score_proposals(
          self,
          execute_model_req: ExecuteModelRequest,
          proposals: SpeculativeProposals,
      ) -> SpeculativeScores:
          """Score the proposed tokens via the scorer model.

          This converts each input sequence to a set of k+1 target sequences.
          The target sequences have the unique continuations to be scored and a
          unique sequence ID that is different from all input sequence ids.

          If a speculative sequence length would exceed the max model length,
          then no speculation is produced for that sequence.

          Args:
              execute_model_req: The execution request.
              proposals: The speculative proposals to score.

          Returns:
              SpeculativeScores: The scores of each speculative token, along
                  with which sequences were ignored during scoring.
          """
          # TODO: perform this on GPU to remove blocking call.
          proposal_lens_list = proposals.proposal_lens.tolist()
          proposal_token_ids_list = proposals.proposal_token_ids.tolist()

          # Drop rows containing the -1 placeholder. _expand_batch indexes the
          # remaining rows by position among the speculative sequences, so
          # this assumes -1 rows are exactly the proposal_len == 0 rows.
          proposal_token_ids_list_without_skips = [
              row for row in proposal_token_ids_list if -1 not in row
          ]

          (spec_indices, non_spec_indices, target_seq_group_metadata_list,
           num_scoring_tokens) = self._expand_batch(
               seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
               proposal_token_ids_list=proposal_token_ids_list_without_skips,
               proposal_lens_list=proposal_lens_list,
           )

          target_sampler_output = self._scorer_worker.execute_model(
              execute_model_req=execute_model_req.clone(
                  seq_group_metadata_list=target_seq_group_metadata_list))
          assert len(target_sampler_output) == 1, "expected single-step output"
          target_sampler_output = target_sampler_output[0]

          (all_tokens, all_probs, spec_logprobs,
           all_hidden_states) = self._contract_batch(
               contracted_bs=len(execute_model_req.seq_group_metadata_list),
               target_sampler_output=target_sampler_output,
               proposals=proposals,
               num_scoring_tokens=num_scoring_tokens,
               non_spec_indices=non_spec_indices,
               spec_indices=spec_indices,
               k=execute_model_req.num_lookahead_slots,
           )

          return SpeculativeScores(
              probs=all_probs,
              token_ids=all_tokens,
              logprobs=spec_logprobs,
              hidden_states=all_hidden_states,
          )

      def _expand_batch(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
          proposal_token_ids_list: List[List[TokenId]],
          proposal_lens_list: List[int],
      ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
          """Given the input sequences and potentially multiple corresponding
          proposal tokens, create a new batch where each sequence has a single
          query token.

          Returns:
              (spec_indices, non_spec_indices, target_seq_group_metadata_list,
              num_scoring_tokens). The target list holds the expanded
              speculative entries first (num_scoring_tokens of them), followed
              by the non-speculative sequences unchanged.
          """
          # Aphrodite currently only supports proposal lens equal to zero or
          # the batch proposal len. This adds some complexity (splitting the
          # batch into spec and non spec sequences) and should be removed in
          # the future. It can be done by supporting per-sequence proposal
          # lens.
          spec_seqs, spec_indices = split_batch_by_proposal_len(
              seq_group_metadata_list,
              proposal_lens_list,
              select_proposal_len_zero=False)
          non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
              seq_group_metadata_list,
              proposal_lens_list,
              select_proposal_len_zero=True)

          target_seq_group_metadata_list = self._create_scoring_model_input(
              seq_group_metadata_list=spec_seqs,
              proposal_token_ids=proposal_token_ids_list,
              # NOTE: We determine the seq ids in the expanded batch using the
              # full seq_group_metadata_list, instead of only spec_seqs.
              target_seq_ids_iter=self._create_target_seq_id_iterator(
                  seq_ids=get_all_seq_ids(seq_group_metadata_list)),
          )

          num_scoring_tokens = len(target_seq_group_metadata_list)
          target_seq_group_metadata_list.extend(non_spec_seqs)

          return (spec_indices, non_spec_indices,
                  target_seq_group_metadata_list, num_scoring_tokens)

      def _contract_batch(
          self, contracted_bs: int, target_sampler_output: SamplerOutput,
          proposals: SpeculativeProposals, num_scoring_tokens: int,
          non_spec_indices: List[int], spec_indices: List[int], k: int
      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                 Optional[torch.Tensor]]:
          """Contract the expanded batch back into its original size.

          This maps the scores of speculative tokens back to their original
          sequences.

          Args:
              contracted_bs: The original batch size, and the batch size that
                  target_sampler_output will be contracted to.
              target_sampler_output: Scorer output for the expanded batch.
                  NOTE: _split_scoring_output overwrites its tensor fields in
                  place.
              proposals: The scored proposals. The number of speculative
                  tokens per sequence is taken from the width of
                  proposals.proposal_token_ids.
              num_scoring_tokens: Number of leading rows of the expanded batch
                  that belong to speculative sequences.
              non_spec_indices: Original batch positions of sequences without
                  speculation.
              spec_indices: Original batch positions of speculative sequences.
              k: Accepted for interface compatibility only. The speculation
                  width is read from proposals.proposal_token_ids instead.

          Returns:
              (all_tokens, all_probs, all_logprobs, all_hidden_states), each
              with leading shape [contracted_bs, num_spec_tokens + 1].
              Non-speculative rows are written only at position 0; the other
              positions keep their fill values (-1 / 0 / -inf / 0).
              all_hidden_states is None when the scorer produced no hidden
              states.
          """
          (target_token_ids, target_probs, target_logprobs, target_hidden_states,
           non_spec_target_token_ids, non_spec_target_probs,
           non_spec_target_logprobs,
           non_spec_target_hidden_states) = self._split_scoring_output(
               target_sampler_output, num_scoring_tokens)

          # The speculation width comes from the proposals tensor. A separate
          # name is used so the `k` argument is not silently shadowed.
          batch_size, num_spec_tokens = proposals.proposal_token_ids.shape

          # Non-speculative sequences are appended unexpanded in _expand_batch,
          # so each contributes one row of non-speculative output. The rest of
          # the original batch is speculative.
          num_non_spec_seqs, _ = non_spec_target_token_ids.shape
          num_spec_seqs = batch_size - num_non_spec_seqs

          # Map the flat [num_spec_seqs * (k + 1)] scoring rows back to
          # [num_spec_seqs, k + 1].
          target_token_ids = target_token_ids.reshape(num_spec_seqs,
                                                      num_spec_tokens + 1)
          target_probs = target_probs.reshape(*target_token_ids.shape,
                                              self._vocab_size)
          target_logprobs = target_logprobs.reshape(target_probs.shape)

          if target_hidden_states is not None:
              target_hidden_states = target_hidden_states.reshape(
                  num_spec_seqs, num_spec_tokens + 1,
                  target_hidden_states.shape[-1])

          all_tokens = target_token_ids.new_full(
              size=(contracted_bs, num_spec_tokens + 1), fill_value=-1)
          all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
          all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                  fill_value=-float("inf"))

          # Decide based on the tensor actually used below. By this point,
          # target_sampler_output.hidden_states has already been overwritten
          # with the non-speculative slice by _split_scoring_output.
          if target_hidden_states is not None:
              all_hidden_states = target_hidden_states.new_zeros(
                  size=(contracted_bs, num_spec_tokens + 1,
                        target_hidden_states.shape[-1]))
          else:
              all_hidden_states = None

          if non_spec_indices:
              all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
              all_probs[non_spec_indices, :1, :] = non_spec_target_probs
              all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
              if all_hidden_states is not None:
                  all_hidden_states[
                      non_spec_indices, :1, :] = non_spec_target_hidden_states

          if spec_indices:
              all_tokens[spec_indices] = target_token_ids
              all_probs[spec_indices] = target_probs
              all_logprobs[spec_indices] = target_logprobs
              if all_hidden_states is not None:
                  all_hidden_states[spec_indices] = target_hidden_states

          return all_tokens, all_probs, all_logprobs, all_hidden_states

      def _create_scoring_model_input(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
          proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
          target_seq_ids_iter: Iterator[TargetSeqId],
      ) -> List[SequenceGroupMetadata]:
          """Given the original input sequences and proposed tokens from the
          draft model, create a list of target sequences that can be used for
          scoring.

          target_seq_ids_iter provides sequence ids for the expanded batch,
          fulfilling the requirement that no seq id in the expanded batch is
          equal to the seq id in the original batch.
          """
          if not seq_group_metadata_list:
              return []

          # The i-th speculative sequence is paired with row i of
          # proposal_token_ids.
          return list(
              chain.from_iterable(
                  self._create_target_seq_group_metadata(
                      seq_group_metadata,
                      proposal_token_ids,
                      i,
                      target_seq_ids_iter,
                  ) for i, seq_group_metadata in enumerate(
                      seq_group_metadata_list)))

      def _create_target_seq_group_metadata(
          self,
          input_seq_group_metadata: SequenceGroupMetadata,
          proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
          batch_index: int,
          target_seq_ids_iter: Iterator[TargetSeqId],
      ) -> List[SequenceGroupMetadata]:
          """Given an input sequence group metadata and a list of draft tokens,
          create a list of target SequenceGroupMetadata, one for each
          token id that needs to be scored.

          Naive speculative decoding requires K target model scores, one for
          each draft model token. However one can add a bonus token such that
          if each token is accepted, then a final token may be sampled from the
          model. This function creates K+1 target SequenceGroupMetadata to take
          advantage of the bonus token.
          """
          assert not input_seq_group_metadata.is_prompt, (
              "Speculating on "
              "prompts not yet supported")
          assert len(input_seq_group_metadata.seq_data) == 1, (
              "Beam search "
              "not supported in speculative decoding")
          input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

          token_ids_to_score = self._get_token_ids_to_score(
              proposal_token_ids[batch_index])

          # Use simpler sampling parameters apart from for final token
          # (in particular don't do seeded sampling) since those sampled tokens
          # aren't used.
          # We don't replace the sampling_params in the greedy case because
          # this also controls whether the probs get modified in the sampler
          # (see use of _modify_greedy_probs_inplace there).
          sampling_params = input_seq_group_metadata.sampling_params
          non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
              if sampling_params.temperature else sampling_params

          target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
          last_index = len(token_ids_to_score) - 1
          for i, token_ids in enumerate(token_ids_to_score):
              # Only the final (bonus) position keeps the request's own params.
              target_sampling_params = sampling_params if i == last_index \
                  else non_bonus_sampling_params
              target_seq_group_metadata_list.append(
                  self._create_single_target_seq_group_metadata(
                      input_seq_group_metadata,
                      input_seq_id,
                      next(target_seq_ids_iter),
                      token_ids,
                      sampling_params=target_sampling_params,
                  ))

          return target_seq_group_metadata_list

      @staticmethod
      def _create_single_target_seq_group_metadata(
          seq_group_metadata: SequenceGroupMetadata,
          seq_id: SeqId,
          target_seq_id: TargetSeqId,
          token_ids: List[TokenId],
          sampling_params: SamplingParams,
      ) -> SequenceGroupMetadata:
          """Create a single target SequenceGroupMetadata.

          Args:
              seq_group_metadata: The metadata for the input sequence.
              seq_id: The input sequence ID.
              target_seq_id: The corresponding target sequence ID.
              token_ids: The list of token ids that are to be appended to the
                  input sequence.
              sampling_params: Sampling params for this target sequence.
          """
          seq_data = seq_group_metadata.seq_data[seq_id]
          # NOTE: the prompt token array object is passed through, not copied.
          prompt_token_ids = seq_data.prompt_token_ids_array
          new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]

          target_seq_data = SequenceData(
              prompt_token_ids,
              _output_token_ids=array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                      new_output_token_ids),
          )
          # This is a hack. Technically, spec decoding should compute
          # num_lookahead slots at one shot, but instead, it expands the batch
          # and evaluate one by one right now. context_len is seq_len - 1
          # because the kv cache is filled by a previous batch in the batch
          # expansion.
          target_seq_data.update_num_computed_tokens(target_seq_data.get_len() -
                                                     1)

          return SequenceGroupMetadata(
              request_id=seq_group_metadata.request_id,
              is_prompt=seq_group_metadata.is_prompt,
              seq_data={target_seq_id: target_seq_data},
              sampling_params=sampling_params,
              block_tables={
                  target_seq_id: seq_group_metadata.block_tables[seq_id],
              },
              lora_request=None,
              token_chunk_size=1,
          )

      def _split_scoring_output(
          self, sampler_output: SamplerOutput, num_scoring_tokens: int
      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                 Optional[torch.Tensor], torch.Tensor, torch.Tensor,
                 torch.Tensor, Optional[torch.Tensor]]:
          """Split the target model output into speculative and
          non-speculative output.

          NOTE: this reuses `sampler_output` as scratch space. On return, its
          sampled_token_probs / sampled_token_ids / logprobs / hidden_states
          fields hold the non-speculative slice, not the original tensors.

          Returns:
              (spec token_ids, probs, logprobs, hidden_states,
              non-spec token_ids, probs, logprobs, hidden_states), each as
              produced by sampler_output_to_torch. The hidden states entries
              are None when the scorer produced none.
          """
          # Aphrodite currently only supports proposal lens equal to zero or
          # the batch proposal len. This adds some complexity (splitting the
          # batch into spec and non spec sequences) and should be removed in
          # the future. It can be done by supporting per-sequence proposal
          # lens.

          # First samples are from speculative scoring, latter samples are non-
          # speculative samples.
          split_sizes = [
              num_scoring_tokens,
              sampler_output.sampled_token_ids.numel() - num_scoring_tokens
          ]
          (spec_probs, non_spec_probs
           ) = sampler_output.sampled_token_probs.split(split_sizes)
          (spec_sampled_tokens, non_spec_sampled_tokens
           ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
          (
              spec_logprobs,
              non_spec_logprobs,
          ) = sampler_output.logprobs.split(split_sizes)

          if sampler_output.hidden_states is not None:
              (
                  spec_hidden_states,
                  non_spec_hidden_states,
              ) = sampler_output.hidden_states.split(split_sizes)
          else:
              spec_hidden_states, non_spec_hidden_states = None, None

          # Convert scores to tensors.
          sampler_output.sampled_token_probs = spec_probs
          sampler_output.sampled_token_ids = spec_sampled_tokens
          sampler_output.logprobs = spec_logprobs
          sampler_output.hidden_states = spec_hidden_states
          (target_token_ids, target_probs, target_logprobs,
           target_hidden_states) = sampler_output_to_torch([sampler_output],
                                                           True)

          # Convert non-speculative output tokens to tensors.
          sampler_output.sampled_token_probs = non_spec_probs
          sampler_output.sampled_token_ids = non_spec_sampled_tokens
          sampler_output.logprobs = non_spec_logprobs
          sampler_output.hidden_states = non_spec_hidden_states
          (non_spec_target_token_ids, non_spec_target_probs,
           non_spec_target_logprobs,
           non_spec_target_hidden_states) = sampler_output_to_torch(
               [sampler_output], True)

          return (target_token_ids, target_probs, target_logprobs,
                  target_hidden_states, non_spec_target_token_ids,
                  non_spec_target_probs, non_spec_target_logprobs,
                  non_spec_target_hidden_states)

      def _create_target_seq_id_iterator(
              self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
          """Create an iterator for creating target sequence ids.

          Target sequence ids are distinct from sequence ids because we create
          a distinct target sequence id for each proposal token to be scored.

          This implementation increments a counter starting at 1 + max of all
          provided input sequence ids. With no input ids it starts at 0
          instead of raising ValueError from max().
          """
          return count(start=max(seq_ids, default=-1) + 1)

      def _get_token_ids_to_score(
          self,
          full_spec_token_ids: List[TokenId]  # shape: [k]
      ) -> List[List[TokenId]]:
          """Given an int tensor of proposal token ids, return a list of
          token ids that should be scored.

          Returns k+1 output lists. The additional one is used for generating
          the bonus token.

          Example:
              Input: [0, 1, 2, 3] (k=4)
              Output: (k+1 lists)
                  []
                  [0]
                  [0, 1]
                  [0, 1, 2]
                  [0, 1, 2, 3]
          """
          # Every prefix of the proposal, from empty up to the full proposal.
          return [
              full_spec_token_ids[:prefix_len]
              for prefix_len in range(len(full_spec_token_ids) + 1)
          ]