top1_proposer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. from typing import List, Optional, Tuple
  2. import torch
  3. from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
  4. SequenceGroupMetadata)
  5. from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
  6. SpeculativeProposer)
  7. from aphrodite.spec_decode.util import sampler_output_to_torch
  8. from aphrodite.task_handler.worker_base import WorkerBase
  9. class Top1Proposer(SpeculativeProposer):
  10. """Helper class which separates out sequences which would exceed the max
  11. model length when speculated upon.
  12. This allows combinations of models such as JackFram/llama-68m draft with
  13. meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
  14. 2048 while Llama2-13b has max_position_embeddings of 4096.
  15. We treat the sequences which exceed the proposal draft model length as
  16. "non-spec sequences". Essentially they skip the draft model and go through
  17. normal decoding in the target model.
  18. Currently, only proposal_lens of 0 and k are supported, where k is a global
  19. batch proposal length. In the future Aphrodite should support per-sequence
  20. proposal lengths.
  21. """
  22. def __init__(
  23. self,
  24. worker: WorkerBase,
  25. device: str,
  26. vocab_size: int,
  27. max_proposal_len: Optional[int] = None,
  28. ):
  29. self._worker = worker
  30. self._device = device
  31. self.max_proposal_len = max_proposal_len
  32. self._vocab_size = vocab_size
  33. def get_proposals(
  34. self,
  35. execute_model_req: ExecuteModelRequest,
  36. ) -> SpeculativeProposals:
  37. """Get speculative proposals given the input batch.
  38. Sequences which would exceed the max model length are skipped during
  39. speculation.
  40. """
  41. proposal_len = execute_model_req.num_lookahead_slots
  42. seq_group_metadata_list = execute_model_req.seq_group_metadata_list
  43. # Split speculative- and non-speculative- sequences.
  44. (
  45. proposal_lens,
  46. nonzero_proposal_len_seqs,
  47. nonzero_proposal_len_indices,
  48. ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
  49. if nonzero_proposal_len_seqs:
  50. # Speculate tokens using the draft worker for the speculative
  51. # sequences.
  52. # If sampler_transposed is true, then maybe_sampler_output's
  53. # token_ids is like [batch] format in proposal_len size list,
  54. # while if it is false, the format would be [proposal_len]
  55. # in batch size list
  56. nonzero_execute_model_req = ExecuteModelRequest(
  57. seq_group_metadata_list=nonzero_proposal_len_seqs,
  58. num_lookahead_slots=proposal_len,
  59. )
  60. maybe_sampler_output, transposed = self._worker.sampler_output(
  61. execute_model_req=nonzero_execute_model_req,
  62. sample_len=proposal_len,
  63. )
  64. (
  65. proposal_lens,
  66. maybe_sampler_output,
  67. nonzero_proposal_len_indices,
  68. ) = self._remove_no_proposal_seqs(proposal_lens,
  69. maybe_sampler_output,
  70. nonzero_proposal_len_indices,
  71. transposed)
  72. else:
  73. # If no sequences can be speculated, set sampler output to None.
  74. maybe_sampler_output = None
  75. transposed = False
  76. # Combine speculative- and non-speculative sequences into the same
  77. # representation.
  78. proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
  79. batch_size=len(seq_group_metadata_list),
  80. proposal_len=proposal_len,
  81. maybe_sampler_output=maybe_sampler_output,
  82. proposal_lens=proposal_lens,
  83. nonzero_proposal_len_indices=nonzero_proposal_len_indices,
  84. sampler_transposed=transposed,
  85. )
  86. proposals = SpeculativeProposals(
  87. proposal_token_ids=proposal_tokens,
  88. proposal_probs=proposal_probs,
  89. proposal_lens=proposal_lens,
  90. )
  91. return proposals
  92. def _split_by_proposal_len(
  93. self,
  94. seq_group_metadata_list: List[SequenceGroupMetadata],
  95. proposal_len: int,
  96. ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
  97. """Split sequences by two groups:
  98. 1. Sequences with non-zero proposal length.
  99. 2. Sequences with zero proposal length (due to disabled speculation
  100. or exceed the maximum model length).
  101. """
  102. proposal_lens: List[int] = []
  103. nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
  104. nonzero_proposal_len_indices: List[int] = []
  105. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  106. # The speculative decoding for this request has been disabled
  107. # (e.g. due to high traffic).
  108. if seq_group_metadata.num_speculative_tokens == 0:
  109. proposal_lens.append(0)
  110. continue
  111. seq_data = next(iter(seq_group_metadata.seq_data.values()))
  112. seq_len = seq_data.get_len()
  113. # Currently only proposal lens of 0 or the global batch proposal len
  114. # are supported.
  115. # If max_proposal_len is defined, then we shall no exccess this
  116. # quota for nonzero_proposal
  117. new_k = 0
  118. if (self.max_proposal_len is None
  119. or seq_len + proposal_len < self.max_proposal_len):
  120. new_k = proposal_len
  121. nonzero_proposal_len_seqs.append(seq_group_metadata)
  122. nonzero_proposal_len_indices.append(i)
  123. proposal_lens.append(new_k)
  124. seq_group_metadata.num_speculative_tokens = new_k
  125. return (
  126. proposal_lens,
  127. nonzero_proposal_len_seqs,
  128. nonzero_proposal_len_indices,
  129. )
  130. def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
  131. nonzero_proposal_len_indices, transposed):
  132. """Remove sequences from nonzero_proposal_len_indices and reset
  133. their proposal_len to 0 the draft worker does not provide a proposal
  134. (maybe_sampler_output=None). This can avoid scoring overheads.
  135. """
  136. # If maybe_sampler_output is None, then the draft worker did not
  137. # provide a proposal for any sequence and thus no action needed.
  138. # Also we do not support transposed maybe_sampler_output for now
  139. # because it seems not straightforward for draft workers outputting
  140. # transposed sampler outputs to handle the case of no proposal.
  141. if maybe_sampler_output is None or transposed:
  142. return (proposal_lens, maybe_sampler_output,
  143. nonzero_proposal_len_indices)
  144. new_proposal_lens: List[int] = []
  145. new_nonzero_proposal_len_indices: List[int] = []
  146. new_maybe_sampler_output: List[SamplerOutput] = []
  147. nonzero_proposal_len_idx_ptr = 0
  148. seq_idx = 0
  149. while seq_idx < len(
  150. proposal_lens) and nonzero_proposal_len_idx_ptr < len(
  151. nonzero_proposal_len_indices):
  152. if seq_idx < nonzero_proposal_len_indices[
  153. nonzero_proposal_len_idx_ptr]:
  154. # Sequence is not in the original nonzero_proposal_len_indices,
  155. # meaning that it has a proposal length of 0 before sending to
  156. # the draft worker.
  157. assert proposal_lens[seq_idx] == 0
  158. new_proposal_lens.append(0)
  159. else:
  160. # Sequence is in the original nonzero_proposal_len_indices
  161. if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
  162. # but does not have a proposal from the draft worker.
  163. new_proposal_lens.append(0)
  164. else:
  165. # and has a proposal from the draft worker. Add it to the
  166. # new nonzero proposal list and keep the sampler output.
  167. new_proposal_lens.append(proposal_lens[seq_idx])
  168. new_nonzero_proposal_len_indices.append(seq_idx)
  169. new_maybe_sampler_output.append(
  170. maybe_sampler_output[nonzero_proposal_len_idx_ptr])
  171. nonzero_proposal_len_idx_ptr += 1
  172. seq_idx += 1
  173. # The remaining sequences should have proposal length of 0.
  174. new_proposal_lens.extend(proposal_lens[seq_idx:])
  175. # We assume sampler_output will not be a list of all Nones.
  176. # In this case this function should not be called.
  177. assert new_maybe_sampler_output
  178. return (new_proposal_lens, new_maybe_sampler_output,
  179. new_nonzero_proposal_len_indices)
  180. def _merge_outputs(
  181. self,
  182. batch_size: int,
  183. proposal_len: int,
  184. maybe_sampler_output: Optional[SamplerOutput],
  185. proposal_lens: List[int],
  186. nonzero_proposal_len_indices: List[int],
  187. sampler_transposed: bool,
  188. ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
  189. """After speculations are produced, merge the speculation results with
  190. the skipped sequences.
  191. """
  192. if maybe_sampler_output is None:
  193. # If no speculative tokens, the sampler output will be None.
  194. # In this case we return empty proposals.
  195. proposal_tokens = torch.full(
  196. size=(
  197. batch_size,
  198. proposal_len,
  199. ),
  200. fill_value=-1,
  201. dtype=torch.long,
  202. device=self._device,
  203. )
  204. proposal_probs = torch.zeros(
  205. batch_size,
  206. proposal_len,
  207. self._vocab_size,
  208. dtype=torch.float32,
  209. device=self._device,
  210. )
  211. proposal_lens_tensor = torch.zeros(len(proposal_lens),
  212. dtype=torch.long,
  213. device=self._device)
  214. return proposal_tokens, proposal_probs, proposal_lens_tensor
  215. sampler_output = maybe_sampler_output
  216. proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
  217. sampler_output, sampler_transposed)
  218. # Now, reformat the output GPU tensors such that each sequence has
  219. # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
  220. entire_proposal_tokens = torch.full(
  221. size=(batch_size, *proposal_tokens.shape[1:]),
  222. fill_value=-1,
  223. dtype=torch.long,
  224. device=self._device,
  225. )
  226. entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
  227. entire_proposal_probs = torch.zeros(
  228. batch_size,
  229. *proposal_probs.shape[1:],
  230. dtype=torch.float32,
  231. device=self._device,
  232. )
  233. entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
  234. proposal_tokens, proposal_probs = (
  235. entire_proposal_tokens,
  236. entire_proposal_probs,
  237. )
  238. proposal_lens_tensor = torch.zeros(batch_size,
  239. dtype=torch.long,
  240. device=self._device)
  241. proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
  242. return proposal_tokens, proposal_probs, proposal_lens_tensor