top1_proposer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from typing import List, Optional, Set, Tuple
  2. import torch
  3. from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
  4. SequenceGroupMetadata)
  5. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  6. SpeculativeProposer)
  7. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  8. from aphrodite.spec_decode.util import sampler_output_to_torch
  9. class Top1Proposer(SpeculativeProposer):
  10. """Helper class which separates out sequences which would exceed the max
  11. model length when speculated upon.
  12. This allows combinations of models such as JackFram/llama-68m draft with
  13. meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
  14. 2048 while Llama2-13b has max_position_embeddings of 4096.
  15. We treat the sequences which exceed the proposal draft model length as
  16. "non-spec sequences". Essentially they skip the draft model and go through
  17. normal decoding in the target model.
  18. Currently, only proposal_lens of 0 and k are supported, where k is a global
  19. batch proposal length. In the future Aphrodite should support per-sequence
  20. proposal lengths.
  21. """
  22. def __init__(
  23. self,
  24. worker: ProposerWorkerBase,
  25. device: str,
  26. vocab_size: int,
  27. max_proposal_len: Optional[int] = None,
  28. ):
  29. self._worker = worker
  30. self._device = device
  31. self.max_proposal_len = max_proposal_len
  32. self._vocab_size = vocab_size
  33. def get_spec_proposals(
  34. self,
  35. execute_model_req: ExecuteModelRequest,
  36. seq_ids_with_bonus_token_in_last_step: Set[int],
  37. ) -> SpeculativeProposals:
  38. """Get speculative proposals given the input batch.
  39. Sequences which would exceed the max model length are skipped during
  40. speculation.
  41. """
  42. proposal_len = execute_model_req.num_lookahead_slots
  43. seq_group_metadata_list = execute_model_req.seq_group_metadata_list
  44. # Split speculative- and non-speculative- sequences.
  45. (
  46. proposal_lens,
  47. nonzero_proposal_len_seqs,
  48. nonzero_proposal_len_indices,
  49. ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
  50. if nonzero_proposal_len_seqs:
  51. # Speculate tokens using the draft worker for the speculative
  52. # sequences.
  53. # If sampler_transposed is true, then maybe_sampler_output's
  54. # token_ids is like [batch] format in proposal_len size list,
  55. # while if it is false, the format would be [proposal_len]
  56. # in batch size list
  57. hidden_states = execute_model_req.previous_hidden_states
  58. if hidden_states is not None:
  59. hidden_states.prune(nonzero_proposal_len_seqs)
  60. nonzero_execute_model_req = ExecuteModelRequest(
  61. seq_group_metadata_list=nonzero_proposal_len_seqs,
  62. num_lookahead_slots=proposal_len,
  63. previous_hidden_states=hidden_states,
  64. )
  65. maybe_sampler_output, transposed = self._worker.sampler_output(
  66. execute_model_req=nonzero_execute_model_req,
  67. sample_len=proposal_len,
  68. seq_ids_with_bonus_token_in_last_step=\
  69. seq_ids_with_bonus_token_in_last_step,
  70. )
  71. (
  72. proposal_lens,
  73. maybe_sampler_output,
  74. nonzero_proposal_len_indices,
  75. ) = self._remove_no_proposal_seqs(proposal_lens,
  76. maybe_sampler_output,
  77. nonzero_proposal_len_indices,
  78. transposed)
  79. else:
  80. # If no sequences can be speculated, set sampler output to None.
  81. maybe_sampler_output = None
  82. transposed = False
  83. # Combine speculative- and non-speculative sequences into the same
  84. # representation.
  85. proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
  86. batch_size=len(seq_group_metadata_list),
  87. proposal_len=proposal_len,
  88. maybe_sampler_output=maybe_sampler_output,
  89. proposal_lens=proposal_lens,
  90. nonzero_proposal_len_indices=nonzero_proposal_len_indices,
  91. sampler_transposed=transposed,
  92. )
  93. proposals = SpeculativeProposals(
  94. proposal_token_ids=proposal_tokens,
  95. proposal_probs=proposal_probs,
  96. proposal_lens=proposal_lens,
  97. no_proposals=maybe_sampler_output is None)
  98. return proposals
  99. def _split_by_proposal_len(
  100. self,
  101. seq_group_metadata_list: List[SequenceGroupMetadata],
  102. proposal_len: int,
  103. ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
  104. """Split sequences by two groups:
  105. 1. Sequences with non-zero proposal length.
  106. 2. Sequences with zero proposal length (due to disabled speculation
  107. or exceed the maximum model length).
  108. """
  109. proposal_lens: List[int] = []
  110. nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
  111. nonzero_proposal_len_indices: List[int] = []
  112. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  113. # The speculative decoding for this request has been disabled
  114. # (e.g. due to high traffic).
  115. if seq_group_metadata.num_speculative_tokens == 0:
  116. proposal_lens.append(0)
  117. continue
  118. seq_data = next(iter(seq_group_metadata.seq_data.values()))
  119. seq_len = seq_data.get_len()
  120. # Currently only proposal lens of 0 or the global batch proposal len
  121. # are supported.
  122. # If max_proposal_len is defined, then we shall not exceed this
  123. # quota for nonzero_proposal
  124. new_k = 0
  125. if (self.max_proposal_len is None
  126. or seq_len + proposal_len < self.max_proposal_len):
  127. new_k = proposal_len
  128. nonzero_proposal_len_seqs.append(seq_group_metadata)
  129. nonzero_proposal_len_indices.append(i)
  130. proposal_lens.append(new_k)
  131. seq_group_metadata.num_speculative_tokens = new_k
  132. return (
  133. proposal_lens,
  134. nonzero_proposal_len_seqs,
  135. nonzero_proposal_len_indices,
  136. )
  137. @staticmethod
  138. def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
  139. nonzero_proposal_len_indices, transposed):
  140. """Remove sequences from nonzero_proposal_len_indices and reset
  141. their proposal_len to 0 the draft worker does not provide a proposal
  142. (maybe_sampler_output=None). This can avoid scoring overheads.
  143. """
  144. # If maybe_sampler_output is None, then the draft worker did not
  145. # provide a proposal for any sequence and thus no action needed.
  146. # Also we do not support transposed maybe_sampler_output for now
  147. # because it seems not straightforward for draft workers outputting
  148. # transposed sampler outputs to handle the case of no proposal.
  149. if maybe_sampler_output is None or transposed:
  150. return (proposal_lens, maybe_sampler_output,
  151. nonzero_proposal_len_indices)
  152. new_proposal_lens: List[int] = []
  153. new_nonzero_proposal_len_indices: List[int] = []
  154. new_maybe_sampler_output: List[SamplerOutput] = []
  155. nonzero_proposal_len_idx_ptr = 0
  156. seq_idx = 0
  157. while seq_idx < len(
  158. proposal_lens) and nonzero_proposal_len_idx_ptr < len(
  159. nonzero_proposal_len_indices):
  160. if seq_idx < nonzero_proposal_len_indices[
  161. nonzero_proposal_len_idx_ptr]:
  162. # Sequence is not in the original nonzero_proposal_len_indices,
  163. # meaning that it has a proposal length of 0 before sending to
  164. # the draft worker.
  165. assert proposal_lens[seq_idx] == 0
  166. new_proposal_lens.append(0)
  167. else:
  168. # Sequence is in the original nonzero_proposal_len_indices
  169. if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
  170. # but does not have a proposal from the draft worker.
  171. new_proposal_lens.append(0)
  172. else:
  173. # and has a proposal from the draft worker. Add it to the
  174. # new nonzero proposal list and keep the sampler output.
  175. new_proposal_lens.append(proposal_lens[seq_idx])
  176. new_nonzero_proposal_len_indices.append(seq_idx)
  177. new_maybe_sampler_output.append(
  178. maybe_sampler_output[nonzero_proposal_len_idx_ptr])
  179. nonzero_proposal_len_idx_ptr += 1
  180. seq_idx += 1
  181. # The remaining sequences should have proposal length of 0.
  182. new_proposal_lens.extend(proposal_lens[seq_idx:])
  183. # We assume sampler_output will not be a list of all Nones.
  184. # In this case this function should not be called.
  185. assert new_maybe_sampler_output
  186. return (new_proposal_lens, new_maybe_sampler_output,
  187. new_nonzero_proposal_len_indices)
  188. def _merge_outputs(
  189. self,
  190. batch_size: int,
  191. proposal_len: int,
  192. maybe_sampler_output: Optional[List[SamplerOutput]],
  193. proposal_lens: List[int],
  194. nonzero_proposal_len_indices: List[int],
  195. sampler_transposed: bool,
  196. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  197. """After speculations are produced, merge the speculation results with
  198. the skipped sequences.
  199. """
  200. if maybe_sampler_output is None:
  201. # If no speculative tokens, the sampler output will be None.
  202. # In this case we return empty proposals.
  203. proposal_tokens = torch.tensor(-1,
  204. dtype=torch.long,
  205. device=self._device).expand(
  206. batch_size, proposal_len)
  207. proposal_probs = torch.tensor(0,
  208. dtype=torch.float32,
  209. device=self._device).expand(
  210. batch_size, proposal_len,
  211. self._vocab_size)
  212. proposal_lens_tensor = torch.tensor(0,
  213. dtype=torch.long,
  214. device=self._device).expand(
  215. len(proposal_lens))
  216. return proposal_tokens, proposal_probs, proposal_lens_tensor
  217. sampler_output = maybe_sampler_output
  218. proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
  219. sampler_output, sampler_transposed)
  220. # Now, reformat the output GPU tensors such that each sequence has
  221. # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
  222. entire_proposal_tokens = proposal_tokens.new_full(
  223. size=(batch_size, *proposal_tokens.shape[1:]),
  224. fill_value=-1,
  225. )
  226. entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
  227. entire_proposal_probs = proposal_probs.new_zeros(
  228. batch_size,
  229. *proposal_probs.shape[1:],
  230. )
  231. entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
  232. proposal_tokens, proposal_probs = (
  233. entire_proposal_tokens,
  234. entire_proposal_probs,
  235. )
  236. proposal_lens_tensor = torch.zeros(batch_size,
  237. dtype=torch.long,
  238. device=self._device)
  239. proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
  240. return proposal_tokens, proposal_probs, proposal_lens_tensor