top1_proposer.py

from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                              SpeculativeProposer)
from aphrodite.spec_decode.util import sampler_output_to_torch
from aphrodite.task_handler.worker_base import WorkerBase


class Top1Proposer(SpeculativeProposer):
  9. """Helper class which separates out sequences which would exceed the max
  10. model length when speculated upon.
  11. This allows combinations of models such as JackFram/llama-68m draft with
  12. meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
  13. 2048 while Llama2-13b has max_position_embeddings of 4096.
  14. We treat the sequences which exceed the proposal draft model length as
  15. "non-spec sequences". Essentially they skip the draft model and go through
  16. normal decoding in the target model.
  17. Currently, only proposal_lens of 0 and k are supported, where k is a global
  18. batch proposal length. In the future Aphrodite should support per-sequence
  19. proposal lengths.
  20. """

    def __init__(
        self,
        worker: WorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None,
    ):
        self._worker = worker
        self._device = device
        self.max_proposal_len = max_proposal_len
        self._vocab_size = vocab_size
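
    # Illustrative construction (values assumed for demonstration, not taken
    # from the original file): a proposer wrapping a draft worker whose max
    # position embeddings are 2048, e.g.
    #   Top1Proposer(draft_worker, device="cuda:0", vocab_size=32000,
    #                max_proposal_len=2048)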

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """

        # Split speculative- and non-speculative sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, maybe_sampler_output's token_ids
            # is a list of length proposal_len whose entries each cover the
            # batch; if it is false, it is a list of length batch_size whose
            # entries each cover proposal_len tokens.
            maybe_sampler_output, transposed = self._worker.sampler_output(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                sample_len=proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

        return proposals
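
    # Illustrative shapes (values assumed): with batch_size=4, proposal_len=3,
    # and vocab_size=32000, get_proposals returns a SpeculativeProposals whose
    #   proposal_token_ids has shape [4, 3]      (-1 marks skipped sequences),
    #   proposal_probs     has shape [4, 3, 32000], and
    #   proposal_lens      has shape [4]         (0 for skipped sequences).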

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length."""

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal
            # len are supported.
            # If max_proposal_len is defined, then we must not exceed this
            # limit for nonzero-proposal sequences.
            if (self.max_proposal_len is None
                    or seq_len + proposal_len < self.max_proposal_len):
                proposal_lens.append(proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        )
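
    # Illustrative example (values assumed): with max_proposal_len=2048 and
    # proposal_len=5, a batch whose sequence lengths are [100, 2046, 500]
    # splits as
    #   proposal_lens                = [5, 0, 5]
    #   nonzero_proposal_len_seqs    = [metadata for seqs 0 and 2]
    #   nonzero_proposal_len_indices = [0, 2]
    # so sequence 1 bypasses the draft model entirely.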

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[SamplerOutput],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  121. """After speculations are produced, merge the speculation results with
  122. the skipped sequences.
  123. """
  124. if maybe_sampler_output is None:
  125. # If no speculative tokens, the sampler output will be None.
  126. # In this case we return empty proposals.
  127. proposal_tokens = torch.full(
  128. size=(
  129. batch_size,
  130. proposal_len,
  131. ),
  132. fill_value=-1,
  133. dtype=torch.long,
  134. device=self._device,
  135. )
  136. proposal_probs = torch.zeros(
  137. batch_size,
  138. proposal_len,
  139. self._vocab_size,
  140. dtype=torch.float32,
  141. device=self._device,
  142. )
  143. proposal_lens_tensor = torch.zeros(len(proposal_lens),
  144. dtype=torch.long,
  145. device=self._device)
  146. return proposal_tokens, proposal_probs, proposal_lens_tensor
  147. sampler_output = maybe_sampler_output
  148. proposal_tokens, proposal_probs = sampler_output_to_torch(
  149. sampler_output, sampler_transposed)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = torch.full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
            dtype=torch.long,
            device=self._device,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(
            batch_size,
            *proposal_probs.shape[1:],
            dtype=torch.float32,
            device=self._device,
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len

        return proposal_tokens, proposal_probs, proposal_lens_tensor
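

# Illustrative sketch (not part of the original module; values assumed, CPU
# only): demonstrates how the merge step pads a batch in which only some
# sequences were speculated, mirroring the -1 token padding and zero
# probabilities used in Top1Proposer._merge_outputs above.
def _merge_outputs_demo() -> None:
    batch_size, proposal_len, vocab_size = 4, 3, 10
    nonzero_proposal_len_indices = [0, 2]  # only seqs 0 and 2 were speculated

    # Pretend the draft worker produced tokens/probs for the two spec seqs.
    spec_tokens = torch.randint(0, vocab_size, (2, proposal_len))
    spec_probs = torch.rand(2, proposal_len, vocab_size)

    # Non-spec rows keep -1 token ids and zero probabilities.
    tokens = torch.full((batch_size, proposal_len), -1, dtype=torch.long)
    probs = torch.zeros(batch_size, proposal_len, vocab_size)
    tokens[nonzero_proposal_len_indices] = spec_tokens
    probs[nonzero_proposal_len_indices] = spec_probs

    # Per-sequence proposal lengths: 0 for skipped rows, proposal_len else.
    lens = torch.zeros(batch_size, dtype=torch.long)
    lens[nonzero_proposal_len_indices] = proposal_len

    print(tokens)       # rows 1 and 3 are all -1
    print(probs.shape)  # torch.Size([4, 3, 10])
    print(lens)         # tensor([3, 0, 3, 0])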