# multi_step_worker.py

from typing import List, Dict, Optional, Tuple
import copy

import torch

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.task_handler.worker import Worker
from aphrodite.spec_decode.interfaces import (
    SpeculativeProposals,
    SpeculativeProposer,
)
from aphrodite.spec_decode.util import sampler_output_to_torch


class MultiStepWorker(Worker):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam
    search. Cache swap operations do not require large modifications. On the
    other hand, beam search requires memory allocations during sequence forks
    and thus requires more thought for MultiStepWorker support.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._proposer: Optional[DraftModelTop1Proposer] = None

    def init_device(self):
        super().init_device()

        self._proposer = DraftModelTop1Proposer(
            self,
            self.device,
            self.max_model_len,
            self.vocab_size,
        )

    @torch.inference_mode()
    def execute_model_multi_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_steps: int,
    ) -> List[SamplerOutput]:
        """Run the model forward pass num_steps times. Returns a list of
        sampler outputs, one per model forward pass.
        """
        self._raise_if_unsupported(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for num_steps tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)

        # Run model num_steps times.
        model_outputs = []
        for _ in range(num_steps):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

            self._append_new_tokens(model_output,
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs
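
    # Illustrative data flow (comment only, not executed): with num_steps=3,
    # the loop above behaves like three chained single-step calls, where each
    # step's sampled token is appended to the copied sequences before the
    # next step runs:
    #
    #   step 0: forward(prompt)           -> t0
    #   step 1: forward(prompt + t0)      -> t1
    #   step 2: forward(prompt + t0 + t1) -> t2
    #
    # The returned list holds one SamplerOutput per step:
    # [out(t0), out(t1), out(t2)].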

    def get_spec_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            max_proposal_len,
        )

    def _append_new_tokens(
        self,
        model_output: SamplerOutput,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)
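
    # Note (illustrative): model_output yields one per-sequence-group entry
    # (each carrying its .samples), aligned positionally with
    # seq_group_metadata_list, so the zip above pairs each group with its
    # freshly sampled token.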

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        Helpful when the Aphrodite scheduler runs in the same process as the
        worker. The alternative, deep-copying, has performance downsides.
        """
        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could
            # change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids.
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[
                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list
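
    # Copy-depth note (illustrative): the list, each SequenceGroupMetadata,
    # each seq_data dict, each SequenceData, and each output_token_ids list
    # get fresh copies; everything else (e.g. sampling params, block tables,
    # prompt token ids) is still shared with the caller.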

    def _assert_enough_kv_space(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        num_steps: int,
    ) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because Aphrodite saves KV for
            # a token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated
            # blocks times the number of slots per block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run the model "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations.")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")


class DraftModelTop1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as a JackFram/llama-68m draft
    with meta-llama/Llama-2-13b-chat-hf, as llama-68m has
    max_position_embeddings of 2048 while Llama-2-13b has
    max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go
    through normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a
    global batch proposal length. In the future Aphrodite should support
    per-sequence proposal lengths.
    """

    def __init__(
        self,
        draft_worker: MultiStepWorker,
        device: str,
        max_model_len: int,
        vocab_size: int,
    ):
        self._draft_worker = draft_worker
        self._device = device
        self._max_model_len = max_model_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        # Split speculative- and non-speculative sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_max_model_len(seq_group_metadata_list,
                                         max_proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                num_steps=max_proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            max_proposal_len=max_proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

        return proposals
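
    # Illustrative call (hypothetical caller and values): a speculative
    # decoding worker might invoke
    #
    #   proposals = proposer.get_proposals(
    #       seq_group_metadata_list,
    #       blocks_to_swap_in={},
    #       blocks_to_swap_out={},
    #       blocks_to_copy={},
    #       max_proposal_len=5,
    #   )
    #
    # receiving 5 proposal tokens (and their probability distributions) for
    # each sequence that fits under the draft model's max length, and a
    # proposal_len of 0 for each sequence that does not.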

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length."""
        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal
            # len are supported.
            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        )
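
    # Worked example (illustrative numbers): with max_model_len=2048 and
    # max_proposal_len=10, a sequence of length 2000 is speculated upon
    # (2000 + 10 < 2048), while a sequence of length 2040 is skipped
    # (2040 + 10 >= 2048) and receives a proposal_len of 0.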

    def _merge_outputs(
        self,
        batch_size: int,
        max_proposal_len: int,
        maybe_sampler_output: Optional[SamplerOutput],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty tensors.
            proposal_tokens = torch.zeros(0,
                                          max_proposal_len,
                                          dtype=torch.long,
                                          device=self._device)
            proposal_probs = torch.zeros(
                0,
                max_proposal_len,
                self._vocab_size,
                dtype=torch.float32,
                device=self._device,
            )
            proposal_lens = torch.zeros(len(proposal_lens),
                                        dtype=torch.long,
                                        device=self._device)
            return proposal_tokens, proposal_probs, proposal_lens

        sampler_output = maybe_sampler_output

        proposal_tokens, proposal_probs = sampler_output_to_torch(
            sampler_output)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = torch.full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
            dtype=torch.long,
            device=self._device,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(
            batch_size,
            *proposal_probs.shape[1:],
            dtype=torch.float32,
            device=self._device,
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens = torch.zeros(batch_size,
                                    dtype=torch.long,
                                    device=self._device)
        proposal_lens[nonzero_proposal_len_indices] = max_proposal_len

        return proposal_tokens, proposal_probs, proposal_lens
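
    # Worked example (illustrative): batch_size=4,
    # nonzero_proposal_len_indices=[0, 2], max_proposal_len=3. The merged
    # tensors look like
    #
    #   proposal_tokens = [[t00, t01, t02],   # speculated
    #                      [ -1,  -1,  -1],   # skipped (over max length)
    #                      [t20, t21, t22],   # speculated
    #                      [ -1,  -1,  -1]]   # skipped (over max length)
    #   proposal_lens   = [3, 0, 3, 0]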