top1_proposer.py

from typing import List, Optional, Set, Tuple

import torch

from aphrodite.common.sequence import (ExecuteModelRequest,
                                       SequenceGroupMetadata)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                              SpeculativeProposer)
from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
from aphrodite.spec_decode.util import sampler_output_to_torch


class Top1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings
    of 2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go
    through normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a
    global batch proposal length. In the future Aphrodite should support
    per-sequence proposal lengths.
    """

    def __init__(
        self,
        worker: ProposerWorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None,
    ):
        self._worker = worker
        self._device = device
        self.max_proposal_len = max_proposal_len
        self._vocab_size = vocab_size

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative and non-speculative sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
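
        # At this point proposal_lens has one entry per sequence group
        # (either 0 or proposal_len), while the nonzero_* lists cover only
        # the sequences that will actually be sent to the draft worker.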
        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            #
            # If sampler_transposed is true, maybe_sampler_output's token_ids
            # is a list of proposal_len entries, each covering the batch;
            # if it is false, it is a list of batch_size entries, each
            # covering proposal_len steps.
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                hidden_states.prune(nonzero_proposal_len_seqs)
            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
            )
            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=(
                    seq_ids_with_bonus_token_in_last_step),
            )
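
            # The draft worker may return None for individual sequences;
            # reset those to a proposal length of 0 so they skip scoring.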
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
            no_proposals=maybe_sampler_output is None)
        return proposals
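
    # Worked example (hypothetical numbers): with proposal_len=3 and a batch
    # of three sequences where sequence 1 is too long to speculate on,
    # get_spec_proposals returns proposal_lens == [3, 0, 3] and a
    # proposal_token_ids tensor whose row 1 is filled with the -1 sentinel.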

    def _split_by_proposal_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Split sequences into two groups:
        1. Sequences with a non-zero proposal length.
        2. Sequences with a zero proposal length (due to disabled speculation
           or exceeding the maximum model length).
        """
        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            # Speculative decoding for this request has been disabled
            # (e.g. due to high traffic).
            if seq_group_metadata.num_speculative_tokens == 0:
                proposal_lens.append(0)
                continue

            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal
            # len are supported.
            # If max_proposal_len is defined, nonzero proposals must not
            # exceed that quota.
            new_k = 0
            if (self.max_proposal_len is None
                    or seq_len + proposal_len < self.max_proposal_len):
                new_k = proposal_len
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            proposal_lens.append(new_k)
            seq_group_metadata.num_speculative_tokens = new_k

        return (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        )
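
    # Worked example (hypothetical numbers): with max_proposal_len=2048,
    # proposal_len=4, and sequence lengths [100, 2047], only the first
    # sequence is speculated on (100 + 4 < 2048), so the method returns
    # ([4, 0], [<seq 0 metadata>], [0]).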

    @staticmethod
    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                 nonzero_proposal_len_indices, transposed):
        """Remove sequences from nonzero_proposal_len_indices and reset their
        proposal_len to 0 when the draft worker does not provide a proposal
        (maybe_sampler_output=None). This avoids scoring overheads.
        """
        # If maybe_sampler_output is None, the draft worker did not provide a
        # proposal for any sequence, so no action is needed. We also do not
        # support transposed maybe_sampler_output for now, because it is not
        # straightforward for draft workers that output transposed sampler
        # outputs to handle the no-proposal case.
        if maybe_sampler_output is None or transposed:
            return (proposal_lens, maybe_sampler_output,
                    nonzero_proposal_len_indices)

        new_proposal_lens: List[int] = []
        new_nonzero_proposal_len_indices: List[int] = []
        new_maybe_sampler_output: List[SamplerOutput] = []
        nonzero_proposal_len_idx_ptr = 0
        seq_idx = 0
        while (seq_idx < len(proposal_lens)
               and nonzero_proposal_len_idx_ptr
               < len(nonzero_proposal_len_indices)):
            if seq_idx < nonzero_proposal_len_indices[
                    nonzero_proposal_len_idx_ptr]:
                # Sequence is not in the original
                # nonzero_proposal_len_indices, meaning that it had a
                # proposal length of 0 before being sent to the draft worker.
                assert proposal_lens[seq_idx] == 0
                new_proposal_lens.append(0)
            else:
                # Sequence is in the original nonzero_proposal_len_indices...
                if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
                    # ...but does not have a proposal from the draft worker.
                    new_proposal_lens.append(0)
                else:
                    # ...and has a proposal from the draft worker. Add it to
                    # the new nonzero proposal list and keep the sampler
                    # output.
                    new_proposal_lens.append(proposal_lens[seq_idx])
                    new_nonzero_proposal_len_indices.append(seq_idx)
                    new_maybe_sampler_output.append(
                        maybe_sampler_output[nonzero_proposal_len_idx_ptr])
                nonzero_proposal_len_idx_ptr += 1
            seq_idx += 1

        # The remaining sequences should have a proposal length of 0.
        new_proposal_lens.extend(proposal_lens[seq_idx:])

        # We assume sampler_output will not be a list of all Nones; in that
        # case this function should not be called.
        assert new_maybe_sampler_output

        return (new_proposal_lens, new_maybe_sampler_output,
                new_nonzero_proposal_len_indices)
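
    # Worked example (hypothetical): proposal_lens=[0, 3, 3] with
    # nonzero_proposal_len_indices=[1, 2], and the draft worker returned
    # [None, out2]. The walk above yields new_proposal_lens=[0, 0, 3],
    # new_nonzero_proposal_len_indices=[2], and keeps only out2.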

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results
        with the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If there are no speculative tokens, the sampler output will be
            # None. In this case we return empty proposals.
            proposal_tokens = torch.tensor(-1,
                                           dtype=torch.long,
                                           device=self._device).expand(
                                               batch_size, proposal_len)
            proposal_probs = torch.tensor(0,
                                          dtype=torch.float32,
                                          device=self._device).expand(
                                              batch_size, proposal_len,
                                              self._vocab_size)
            proposal_lens_tensor = torch.tensor(0,
                                                dtype=torch.long,
                                                device=self._device).expand(
                                                    len(proposal_lens))
            return proposal_tokens, proposal_probs, proposal_lens_tensor
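
        # At least one sequence produced a proposal; convert the sampler
        # output into [batch, proposal_len] token and probability tensors.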
        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
        return proposal_tokens, proposal_probs, proposal_lens_tensor
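

# A minimal usage sketch (hypothetical, not part of the original module). It
# exercises only the no-proposal path of _merge_outputs, which never touches
# the draft worker, so None stands in for the worker despite the type hint.
if __name__ == "__main__":
    proposer = Top1Proposer(worker=None, device="cpu", vocab_size=8)
    tokens, probs, lens = proposer._merge_outputs(
        batch_size=2,
        proposal_len=3,
        maybe_sampler_output=None,
        proposal_lens=[0, 0],
        nonzero_proposal_len_indices=[],
        sampler_transposed=False,
    )
    # Every sequence gets a sentinel proposal: -1 tokens, zero probs/lens.
    assert tokens.shape == (2, 3) and bool((tokens == -1).all())
    assert probs.shape == (2, 3, 8) and bool((probs == 0).all())
    assert bool((lens == 0).all())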