top1_proposer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. from typing import List, Optional, Tuple
  2. import torch
  3. from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
  4. SequenceGroupMetadata)
  5. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  6. SpeculativeProposer)
  7. from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
  8. from aphrodite.spec_decode.util import sampler_output_to_torch
  9. class Top1Proposer(SpeculativeProposer):
  10. """Helper class which separates out sequences which would exceed the max
  11. model length when speculated upon.
  12. This allows combinations of models such as JackFram/llama-68m draft with
  13. meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
  14. 2048 while Llama2-13b has max_position_embeddings of 4096.
  15. We treat the sequences which exceed the proposal draft model length as
  16. "non-spec sequences". Essentially they skip the draft model and go through
  17. normal decoding in the target model.
  18. Currently, only proposal_lens of 0 and k are supported, where k is a global
  19. batch proposal length. In the future Aphrodite should support per-sequence
  20. proposal lengths.
  21. """
  22. def __init__(
  23. self,
  24. worker: ProposerWorkerBase,
  25. device: str,
  26. vocab_size: int,
  27. max_proposal_len: Optional[int] = None,
  28. ):
  29. self._worker = worker
  30. self._device = device
  31. self.max_proposal_len = max_proposal_len
  32. self._vocab_size = vocab_size
  33. def get_spec_proposals(
  34. self,
  35. execute_model_req: ExecuteModelRequest,
  36. ) -> SpeculativeProposals:
  37. """Get speculative proposals given the input batch.
  38. Sequences which would exceed the max model length are skipped during
  39. speculation.
  40. """
  41. proposal_len = execute_model_req.num_lookahead_slots
  42. seq_group_metadata_list = execute_model_req.seq_group_metadata_list
  43. # Split speculative- and non-speculative- sequences.
  44. (
  45. proposal_lens,
  46. nonzero_proposal_len_seqs,
  47. nonzero_proposal_len_indices,
  48. ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
  49. if nonzero_proposal_len_seqs:
  50. # Speculate tokens using the draft worker for the speculative
  51. # sequences.
  52. # If sampler_transposed is true, then maybe_sampler_output's
  53. # token_ids is like [batch] format in proposal_len size list,
  54. # while if it is false, the format would be [proposal_len]
  55. # in batch size list
  56. nonzero_execute_model_req = ExecuteModelRequest(
  57. seq_group_metadata_list=nonzero_proposal_len_seqs,
  58. num_lookahead_slots=proposal_len,
  59. )
  60. maybe_sampler_output, transposed = self._worker.sampler_output(
  61. execute_model_req=nonzero_execute_model_req,
  62. sample_len=proposal_len,
  63. )
  64. (
  65. proposal_lens,
  66. maybe_sampler_output,
  67. nonzero_proposal_len_indices,
  68. ) = self._remove_no_proposal_seqs(proposal_lens,
  69. maybe_sampler_output,
  70. nonzero_proposal_len_indices,
  71. transposed)
  72. else:
  73. # If no sequences can be speculated, set sampler output to None.
  74. maybe_sampler_output = None
  75. transposed = False
  76. # Combine speculative- and non-speculative sequences into the same
  77. # representation.
  78. proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
  79. batch_size=len(seq_group_metadata_list),
  80. proposal_len=proposal_len,
  81. maybe_sampler_output=maybe_sampler_output,
  82. proposal_lens=proposal_lens,
  83. nonzero_proposal_len_indices=nonzero_proposal_len_indices,
  84. sampler_transposed=transposed,
  85. )
  86. proposals = SpeculativeProposals(
  87. proposal_token_ids=proposal_tokens,
  88. proposal_probs=proposal_probs,
  89. proposal_lens=proposal_lens,
  90. )
  91. return proposals
  92. def _split_by_proposal_len(
  93. self,
  94. seq_group_metadata_list: List[SequenceGroupMetadata],
  95. proposal_len: int,
  96. ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
  97. """Split sequences by two groups:
  98. 1. Sequences with non-zero proposal length.
  99. 2. Sequences with zero proposal length (due to disabled speculation
  100. or exceed the maximum model length).
  101. """
  102. proposal_lens: List[int] = []
  103. nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
  104. nonzero_proposal_len_indices: List[int] = []
  105. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  106. # The speculative decoding for this request has been disabled
  107. # (e.g. due to high traffic).
  108. if seq_group_metadata.num_speculative_tokens == 0:
  109. proposal_lens.append(0)
  110. continue
  111. seq_data = next(iter(seq_group_metadata.seq_data.values()))
  112. seq_len = seq_data.get_len()
  113. # Currently only proposal lens of 0 or the global batch proposal len
  114. # are supported.
  115. # If max_proposal_len is defined, then we shall no exccess this
  116. # quota for nonzero_proposal
  117. new_k = 0
  118. if (self.max_proposal_len is None
  119. or seq_len + proposal_len < self.max_proposal_len):
  120. new_k = proposal_len
  121. nonzero_proposal_len_seqs.append(seq_group_metadata)
  122. nonzero_proposal_len_indices.append(i)
  123. proposal_lens.append(new_k)
  124. seq_group_metadata.num_speculative_tokens = new_k
  125. return (
  126. proposal_lens,
  127. nonzero_proposal_len_seqs,
  128. nonzero_proposal_len_indices,
  129. )
  130. @staticmethod
  131. def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
  132. nonzero_proposal_len_indices, transposed):
  133. """Remove sequences from nonzero_proposal_len_indices and reset
  134. their proposal_len to 0 the draft worker does not provide a proposal
  135. (maybe_sampler_output=None). This can avoid scoring overheads.
  136. """
  137. # If maybe_sampler_output is None, then the draft worker did not
  138. # provide a proposal for any sequence and thus no action needed.
  139. # Also we do not support transposed maybe_sampler_output for now
  140. # because it seems not straightforward for draft workers outputting
  141. # transposed sampler outputs to handle the case of no proposal.
  142. if maybe_sampler_output is None or transposed:
  143. return (proposal_lens, maybe_sampler_output,
  144. nonzero_proposal_len_indices)
  145. new_proposal_lens: List[int] = []
  146. new_nonzero_proposal_len_indices: List[int] = []
  147. new_maybe_sampler_output: List[SamplerOutput] = []
  148. nonzero_proposal_len_idx_ptr = 0
  149. seq_idx = 0
  150. while seq_idx < len(
  151. proposal_lens) and nonzero_proposal_len_idx_ptr < len(
  152. nonzero_proposal_len_indices):
  153. if seq_idx < nonzero_proposal_len_indices[
  154. nonzero_proposal_len_idx_ptr]:
  155. # Sequence is not in the original nonzero_proposal_len_indices,
  156. # meaning that it has a proposal length of 0 before sending to
  157. # the draft worker.
  158. assert proposal_lens[seq_idx] == 0
  159. new_proposal_lens.append(0)
  160. else:
  161. # Sequence is in the original nonzero_proposal_len_indices
  162. if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
  163. # but does not have a proposal from the draft worker.
  164. new_proposal_lens.append(0)
  165. else:
  166. # and has a proposal from the draft worker. Add it to the
  167. # new nonzero proposal list and keep the sampler output.
  168. new_proposal_lens.append(proposal_lens[seq_idx])
  169. new_nonzero_proposal_len_indices.append(seq_idx)
  170. new_maybe_sampler_output.append(
  171. maybe_sampler_output[nonzero_proposal_len_idx_ptr])
  172. nonzero_proposal_len_idx_ptr += 1
  173. seq_idx += 1
  174. # The remaining sequences should have proposal length of 0.
  175. new_proposal_lens.extend(proposal_lens[seq_idx:])
  176. # We assume sampler_output will not be a list of all Nones.
  177. # In this case this function should not be called.
  178. assert new_maybe_sampler_output
  179. return (new_proposal_lens, new_maybe_sampler_output,
  180. new_nonzero_proposal_len_indices)
  181. def _merge_outputs(
  182. self,
  183. batch_size: int,
  184. proposal_len: int,
  185. maybe_sampler_output: Optional[List[SamplerOutput]],
  186. proposal_lens: List[int],
  187. nonzero_proposal_len_indices: List[int],
  188. sampler_transposed: bool,
  189. ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
  190. """After speculations are produced, merge the speculation results with
  191. the skipped sequences.
  192. """
  193. if maybe_sampler_output is None:
  194. # If no speculative tokens, the sampler output will be None.
  195. # In this case we return empty proposals.
  196. proposal_tokens = torch.tensor(-1,
  197. dtype=torch.long,
  198. device=self._device).expand(
  199. batch_size, proposal_len)
  200. proposal_probs = torch.tensor(0,
  201. dtype=torch.float32,
  202. device=self._device).expand(
  203. batch_size, proposal_len,
  204. self._vocab_size)
  205. proposal_lens_tensor = torch.tensor(0,
  206. dtype=torch.long,
  207. device=self._device).expand(
  208. len(proposal_lens))
  209. return proposal_tokens, proposal_probs, proposal_lens_tensor
  210. sampler_output = maybe_sampler_output
  211. proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
  212. sampler_output, sampler_transposed)
  213. # Now, reformat the output GPU tensors such that each sequence has
  214. # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
  215. entire_proposal_tokens = proposal_tokens.new_full(
  216. size=(batch_size, *proposal_tokens.shape[1:]),
  217. fill_value=-1,
  218. )
  219. entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
  220. entire_proposal_probs = proposal_probs.new_zeros(
  221. batch_size,
  222. *proposal_probs.shape[1:],
  223. )
  224. entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
  225. proposal_tokens, proposal_probs = (
  226. entire_proposal_tokens,
  227. entire_proposal_probs,
  228. )
  229. proposal_lens_tensor = torch.zeros(batch_size,
  230. dtype=torch.long,
  231. device=self._device)
  232. proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
  233. return proposal_tokens, proposal_probs, proposal_lens_tensor