multi_step_model_runner.py

import dataclasses
import functools
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Union)

import torch

from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        IntermediateTensors, Logprob,
                                        SequenceGroupMetadata, SequenceOutput)
from aphrodite.common.utils import PyObjectCache
from aphrodite.distributed import get_pp_group
from aphrodite.modeling.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                               SamplerOutput, SamplingMetadata,
                                               get_logprobs,
                                               get_pythonized_sample_results)
from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
from aphrodite.worker.model_runner import (
    GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata)
from aphrodite.worker.model_runner_base import (
    BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
    _init_frozen_model_input_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend

MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer", "rocm-flash-attn"]


def seq_output_builder():
    return SequenceOutput(
        0, 0,
        {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})


def completion_seq_group_output_builder():
    return CompletionSequenceGroupOutput([], None)


# Used by pythonization to reduce python object allocations
class PythonizationCache:

    def __init__(self):
        self.cached_seq_output = PyObjectCache(seq_output_builder)
        self.cached_completion_seq_group_output = PyObjectCache(
            completion_seq_group_output_builder)

    def reset(self):
        self.cached_seq_output.reset()
        self.cached_completion_seq_group_output.reset()


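# Illustrative sketch (not part of the runtime path): how the cache above is
# meant to be used by the pythonization code further down. Objects handed out
# by PyObjectCache.get_object() are recycled in place rather than reallocated,
# and reset() makes the whole pool reusable for the next multi-step batch.
#
#     cache = PythonizationCache()
#     seq_out = cache.cached_seq_output.get_object()  # recycled SequenceOutput
#     seq_out.output_token = 42                       # overwrite fields in place
#     cache.reset()                                   # start reusing objects

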
@dataclass
class ModelOutput:
    """The output of a single model forward pass.

    The sampler_output_ready_event is set when the tensors in
    sampler_output are ready (the model+sampler forward pass has
    completed). We use the event to synchronize the GPU->CPU transfer,
    which we want to only run when the data has been written to the
    GPU tensors. Until the event is ready, the tensors in sampler_output
    will have garbage data.

    There are two scenarios:
    1. The output tensors are ready and we can pythonize them immediately.
    2. The output tensors are not ready and we need to wait for the event
       to be ready.
    """
    sampler_output: SamplerOutput
    sampler_output_ready_event: torch.cuda.Event
    sampled_token_ids: Optional[torch.Tensor] = None
    pythonized: bool = False
    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None
    pythonization_cache: Optional[PythonizationCache] = None

    def pythonize(
        self,
        input_metadata: "StatefulModelInput",
        copy_stream: torch.cuda.Stream,
        pinned_sampled_token_buffer: torch.Tensor,
    ) -> None:
        """Pythonize the output. Blocking."""
        if not self.pythonized:
            self._pythonize_sampler_output(
                input_metadata, copy_stream, pinned_sampled_token_buffer, True)
            self.pythonized = True

    def maybe_pythonize(
        self,
        input_metadata: "StatefulModelInput",
        copy_stream: torch.cuda.Stream,
        pinned_sampled_token_buffer: torch.Tensor,
    ) -> None:
        """Pythonize the output if ready, else return None. Non-blocking."""
        if not self.pythonized:
            self.pythonized = self._pythonize_sampler_output(
                input_metadata, copy_stream, pinned_sampled_token_buffer,
                False)

    def _pythonize_sampler_output(
        self,
        input_metadata: "StatefulModelInput",
        copy_stream: torch.cuda.Stream,
        pinned_sampled_token_buffer: torch.Tensor,
        blocking: bool,
    ) -> bool:
        """
        If blocking is set, will block until the forward pass for the output
        is ready and pythonize the output. Upon completing Pythonization,
        erases self.logprobs (note that a non-blocking call that is performed
        when the sampler output is not yet ready, will not erase
        self.logprobs.)
        """
        assert self.sampled_token_ids is not None
        if not blocking and not self.sampler_output_ready_event.query():
            return False

        if blocking:
            self.sampler_output_ready_event.synchronize()
        with torch.cuda.stream(copy_stream):
            _pythonize_sampler_output(
                input_metadata,
                self.sampler_output,
                pinned_sampled_token_buffer,
                self.sampled_token_ids, self.logprobs,
                self.pythonization_cache)

        # Erase the logprobs GPU-side tensor.
        # Note that although _pythonize_sampler_output() runs in its
        # own CUDA stream, nonetheless _pythonize_sampler_output()
        # cannot return until Pythonization is complete; therefore
        # we know that by the time the CPU reaches this point,
        # `self.logprobs` is no longer needed.
        self.logprobs = None
        return True


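# Illustrative sketch (not part of the runtime path): the two scenarios from
# the ModelOutput docstring above. While the CPU is ahead of the GPU,
# maybe_pythonize() is polled and only converts outputs whose ready event has
# fired; on the final step, pythonize() blocks on the event to drain whatever
# is left. `model_output`, `stateful_input`, `copy_stream` and `pinned_buf`
# are placeholders for the objects the runner passes in below.
#
#     # scenario 2: event may not have fired yet; returns without converting
#     model_output.maybe_pythonize(stateful_input, copy_stream, pinned_buf)
#
#     # scenario 1 (or last step): wait for the event, then convert
#     model_output.pythonize(stateful_input, copy_stream, pinned_buf)

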
@dataclass(frozen=False)
class StatefulModelInput(BroadcastableModelInput):
    # actual frozen model input dataclass passed to _base_model_runner
    frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None

    # list of model outputs for each step, may not be all pythonized
    cached_outputs: List[ModelOutput] = field(default_factory=list)

    # used to pass sampled token ids from the last step to the current step
    # for TP workers. Used to append to the end of outputs and used by
    # advance_step
    last_sampled_token_ids: Optional[torch.Tensor] = None
    current_step: int = 0
    is_multi_step: bool = True
    is_last_step: bool = False
    is_first_multi_step: bool = False
    # ping-pong data structures for multi-step to wait on the previous step
    step_cuda_events: List[torch.cuda.Event] = field(
        default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
    num_seqs: int = -1
    num_queries: int = -1

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        assert self.frozen_model_input is not None
        tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
        new_tensor_dict = {
            "last_sampled_token_ids": self.last_sampled_token_ids,
            "current_step": self.current_step,
            "is_multi_step": self.is_multi_step,
            "is_last_step": self.is_last_step,
            "is_first_multi_step": self.is_first_multi_step,
            "num_seqs": self.num_seqs,
            "num_queries": self.num_queries,
        }
        tensor_dict.update(new_tensor_dict)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "StatefulModelInput":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        tensor_dict = _init_frozen_model_input_from_tensor_dict(
            ModelInputForGPUWithSamplingMetadata, tensor_dict)
        return cls(**tensor_dict)

    def record_step_event(self, current_stream: torch.cuda.Stream):
        # record the event for the current step so that the next step can
        # sync on it. We modulo by 2 to keep the events in a circular buffer
        # and support any attn backends that may be supported in the future,
        # ie Flashinfer would want two DecodeWrappers to overlap the CPU and
        # GPU.
        self.step_cuda_events[self.current_step & 1] = torch.cuda.Event(
            blocking=True)
        self.step_cuda_events[self.current_step & 1].record(current_stream)

    def wait_previous_step(self):
        # These cuda events are an explicit synchronization to ensure that
        # advance_step() (for other attn backends that may be supported in
        # the future) does not clobber any data structures that are also used
        # by any enqueued forward steps. For the distributed case, only a
        # single event is needed, but for the single-GPU case, since we can
        # let the CPU run much further ahead, two events allow us to overlap
        # the advance_step with the previous forward (ie using two
        # DecodeWrappers for the flashinfer backend).
        self.step_cuda_events[(self.current_step + 1) & 1].wait()

    def add_sampler_output(
        self,
        sampler_output: SamplerOutput,
        sampled_token_ids: Optional[torch.Tensor] = None,
    ):
        self.cached_outputs.append(
            ModelOutput(
                sampler_output=sampler_output,
                sampler_output_ready_event=None,
                sampled_token_ids=sampled_token_ids,
                pythonized=False,
            ))


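# Illustrative sketch (hypothetical step loop, not part of the module): the
# intended ping-pong use of step_cuda_events. Step N records an event on the
# current stream after its forward is enqueued; step N waits on the event
# recorded by step N-1 (slot (current_step + 1) & 1) before advance_step()
# touches shared metadata, so the CPU can run at most one step ahead.
#
#     stream = torch.cuda.current_stream()
#     if not state.is_first_multi_step:
#         state.wait_previous_step()     # sync with the previous forward
#         # ... advance_step() here ...
#     # ... enqueue forward for this step ...
#     state.record_step_event(stream)    # let the next step sync on this
#     state.current_step += 1

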
# MutableModelInputForGPUWithMultiStepMetadata is not a subclass of
# ModelInputForGPU, but it wraps the actual input dataclass and adds
# multi-step metadata
# mypy: disable-error-code=type-var
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
    # mypy: enable-error-code=type-var

    def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # uses the base model runner to execute the model and wraps it with
        # multi-step logic
        self._base_model_runner: GPUModelRunnerBase = base_model_runner

        self.is_multi_step = self.scheduler_config.is_multi_step
        self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

        self.pythonization_cache = PythonizationCache()

    @functools.cached_property
    def _copy_stream(self):
        # used to copy tensors from GPU to CPU asynchronously
        return torch.cuda.Stream()

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
        model_input = StatefulModelInput.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> StatefulModelInput:
        frozen_model_input = self._base_model_runner.prepare_model_input(
            seq_group_metadata_list, virtual_engine, finished_requests_ids)

        model_input = StatefulModelInput(
            frozen_model_input=frozen_model_input,
            num_seqs=len(frozen_model_input.seq_lens),
            num_queries=len(frozen_model_input.query_lens),
        )
        return model_input

    def _async_process_outputs(self, model_input: StatefulModelInput,
                               output_proc_callback: Callable):
        # Proceed with pythonization and output_proc in order.
        # Stop on the first output that fails to pythonize.
        output_proc_callback()

        cont = True
        for model_output in model_input.cached_outputs:
            if not model_output.pythonized:
                model_output.maybe_pythonize(model_input, self._copy_stream,
                                             self.pinned_sampled_token_ids)
                if model_output.pythonized:
                    ctx = output_proc_callback.keywords["ctx"]
                    ctx.append_output(
                        outputs=[model_output.sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False)

                    output_proc_callback()
                else:
                    cont = False

            if not cont:
                break

    def _final_process_outputs(self, model_input: StatefulModelInput,
                               output_proc_callback: Optional[Callable]):
        assert model_input.frozen_model_input is not None

        has_async_callback = output_proc_callback is not None

        outputs = []
        for output_id in range(len(model_input.cached_outputs)):
            output = model_input.cached_outputs[output_id]
            is_last_step = output_id == len(model_input.cached_outputs) - 1

            # For the non-async case:
            #   -- We simply add the outputs
            # For the async case:
            #   -- Invoke callback, pythonize, add to callback queue and repeat
            #   -- For the last output, just add to the callback queue
            if has_async_callback:
                assert output_proc_callback is not None

                # Invoke callback before pythonize (to overlap with GPU)
                output_proc_callback()

                # Pythonize
                if not output.pythonized:
                    output.pythonize(model_input, self._copy_stream,
                                     self.pinned_sampled_token_ids)

                    # For the non-last step, add to the callback queue to
                    # chain callbacks=>pythonize pairs (for GPU overlap)
                    if not is_last_step:
                        ctx = output_proc_callback.keywords[  # type: ignore
                            "ctx"]  # type: ignore
                        ctx.append_output(
                            outputs=[output.sampler_output],
                            seq_group_metadata_list=ctx.
                            seq_group_metadata_list,
                            scheduler_outputs=ctx.scheduler_outputs,
                            is_async=False,
                            is_last_step=False)
                    else:
                        outputs.append(output.sampler_output)
            else:
                output.pythonize(model_input, self._copy_stream,
                                 self.pinned_sampled_token_ids)
                outputs.append(output.sampler_output)

        return outputs

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: StatefulModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """
        Execute the model for a single step and update multi-step
        metadata
        """
        assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None

        # path for warm up runs
        if not model_input.is_multi_step:
            return self._base_model_runner.execute_model(
                frozen_model_input, kv_caches, intermediate_tensors, num_steps)

        # make sure we skip the sampler on the last rank and only pythonize
        # if CPU is ahead.
        if self.is_driver_worker and get_pp_group().is_last_rank:
            if self.pinned_sampled_token_ids is None:
                self.pinned_sampled_token_ids = torch.zeros(
                    (self.scheduler_config.max_num_seqs, 1),
                    dtype=torch.long,
                    device="cpu",
                    pin_memory=True,
                )

            self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
                True)
            if frozen_model_input.sampling_metadata:
                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
                    True)

        # some pre-execute model logic for multi-step:
        #   - if it's the first step, we need to reset the sampling tensors
        #   - if it's not the first step, we need to advance the step using
        #     the appended sampler output from the last iteration
        #   - also maybe pythonize if CPU is ahead of GPU
        current_stream = torch.cuda.current_stream()
        if not model_input.is_first_multi_step:
            # Explicitly block on the previous step's forward to make sure we
            # don't clobber any GPU tensors still in use.
            # This is not needed for the flash-attn backend, but for other
            # attn backends such as flashinfer that perform extra CPU
            # operations on input metadata we may need to synchronize any CPU
            # operations that might clobber enqueued forwards. (prevents CPU
            # from running too far ahead if needed)
            model_input.wait_previous_step()
            model_input = self._advance_step(
                model_input, model_input.cached_outputs[-1].sampler_output)

        output_proc_callback = None
        if frozen_model_input.async_callback is not None:
            output_proc_callback = frozen_model_input.async_callback
            assert output_proc_callback is not None
            async_callback = functools.partial(
                self._async_process_outputs,
                model_input=model_input,
                output_proc_callback=output_proc_callback)

            frozen_model_input = dataclasses.replace(  # type: ignore
                model_input.frozen_model_input,
                async_callback=async_callback)
            assert frozen_model_input is not None

        # Execute the model
        output = self._base_model_runner.execute_model(
            frozen_model_input, kv_caches, intermediate_tensors, num_steps=1)

        # record the event for the current step so that the next step can sync
        model_input.record_step_event(current_stream)

        if get_pp_group().is_last_rank and self.is_driver_worker:
            assert (len(output) == 1
                    ), "MultiStepModelRunner requires single-step base_models"

            # event for the pythonization so that we only pythonize if the
            # tensors are ready. May be able to be combined with the step
            # event
            output_ready_event = torch.cuda.Event()
            output_ready_event.record(current_stream)
            if self.parallel_config.pipeline_parallel_size > 1:
                output[0].sampled_token_ids_cpu = (
                    output[0].sampled_token_ids.cpu())
            model_input.cached_outputs.append(
                ModelOutput(output[0], output_ready_event,
                            output[0].sampled_token_ids, False,
                            output[0].logprobs, self.pythonization_cache))

            # These GPU tensors are not required by multi-step;
            # erase them to ensure they are not pythonized or
            # transferred to CPU
            output[0].sampled_token_ids = None
            output[0].sampled_token_probs = None
            output[0].logprobs = None

            # Pythonize the output if CPU is ahead and the previous step is
            # ready.
            if frozen_model_input.async_callback is None:
                for model_output in model_input.cached_outputs:
                    model_output.maybe_pythonize(model_input,
                                                 self._copy_stream,
                                                 self.pinned_sampled_token_ids)

        model_input.current_step += 1

        if not get_pp_group().is_last_rank:
            # Should be IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            return output
        if not self.is_driver_worker:
            return []

        # Pythonize the output and block if needed since it is the last step
        if model_input.is_last_step:
            outputs = self._final_process_outputs(model_input,
                                                  output_proc_callback)
            self.pythonization_cache.reset()
            return outputs

        # should be [SamplerOutput]
        return output

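    # Illustrative sketch (pseudocode-style comments, not executed): the
    # per-step flow that execute_model() above implements on the driver /
    # last-PP rank for one multi-step decode iteration:
    #
    #     if not is_first_multi_step:
    #         wait_previous_step()                 # don't clobber GPU inputs
    #         _advance_step(...)                   # update attn metadata in place
    #     output = base_runner.execute_model(...)  # enqueue forward + sampler
    #     record_step_event(current_stream)        # for the next step's wait
    #     cached_outputs.append(ModelOutput(...))  # GPU tensors + ready event
    #     maybe_pythonize(...)                     # only if the CPU is ahead
    #     if is_last_step:
    #         _final_process_outputs(...)          # blocking pythonize + return
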
    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed  # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _advance_step(self, model_input: StatefulModelInput,
                      out: SamplerOutput) -> StatefulModelInput:

        if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
            raise ValueError(
                f"Multi-step not supported for attention backend: "
                f"{self.attn_backend.get_name()}. Set "
                "APHRODITE_ATTENTION_BACKEND to a value from "
                f"{MULTI_STEP_ATTENTION_BACKENDS}.")

        sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
        num_seqs = model_input.num_seqs
        num_queries = model_input.num_queries
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None
        attn_metadata = frozen_model_input.attn_metadata
        assert attn_metadata is not None

        attn_metadata.advance_step(
            frozen_model_input,
            sampled_token_ids,
            self.block_size,
            num_seqs,
            num_queries,
        )

        return model_input

    def load_model(self) -> None:
        return self._base_model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        return self._base_model_runner.save_sharded_state(
            path, pattern, max_size)

    def save_tensorized_model(self,
                              tensorizer_config: TensorizerConfig) -> None:
        return self._base_model_runner.save_tensorized_model(tensorizer_config)

    def profile_run(self) -> None:
        return self._base_model_runner.profile_run()

    def remove_all_loras(self):
        return self._base_model_runner.remove_all_loras()

    def capture_model(self, kv_caches: List[List]) -> None:
        return self._base_model_runner.capture_model(kv_caches)

    @property
    def vocab_size(self) -> int:
        return self._base_model_runner.vocab_size


DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
                                   Optional[List[SampleLogprobs]]]


def deferred_pythonize_logprobs(
    output: SamplerOutput,
    sampling_metadata: SamplingMetadata,
    logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
    """Perform deferred logprob Pythonization.

    1. Pythonize GPU-side sampler result tensors into a CPU-side sampler
       result.
    2. Pythonize the GPU-side logprobs tensor into CPU-side logprobs lists,
       utilizing the Pythonized sampler result computed in step 1.

    These deferred computations are not required for single-step scheduling
    or the `profile_run()` phase of multi-step scheduling.

    Args:
        output: sampler output (under deferred Pythonization)
        sampling_metadata
        logprobs_tensor: GPU-side tensor of logprobs computed during sampling

    Returns:
        prompt_logprobs (CPU), sample_logprobs (CPU)
    """

    # - Deferred pythonization of sample result
    sampler_result = get_pythonized_sample_results(
        output.deferred_sample_results_args)

    # - Erase the GPU-side deferred sample_result
    #   computation args to ensure it is never
    #   pythonized or transferred to CPU
    output.deferred_sample_results_args = None

    # - Deferred pythonization of logprobs
    (
        prompt_logprobs,
        sample_logprobs,
    ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
    assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
    assert len(sample_logprobs) == len(sampling_metadata.seq_groups)

    return prompt_logprobs, sample_logprobs


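# Illustrative sketch (hypothetical call, not part of the module): the helper
# above is only invoked from _pythonize_sampler_output() below when the
# sampler's CPU output was skipped *and* some request asked for logprobs; it
# converts the GPU-side deferred sample results and logprobs tensor into CPU
# lists in one shot.
#
#     prompt_lps, sample_lps = deferred_pythonize_logprobs(
#         output, sampling_metadata, logprobs_tensor)
#     # len(prompt_lps) == len(sample_lps) == len(sampling_metadata.seq_groups)

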
def _pythonize_sampler_output(
    model_input: StatefulModelInput,
    output: SamplerOutput,
    pinned_sampled_token_buffer: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    logprobs_tensor: Optional[torch.Tensor],
    cache: Optional[PythonizationCache],
) -> None:
    """ This function is only called when the output tensors are ready.
    See :class:`ModelOutput`.

    Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
    adding a Pythonized output data structure
    (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.

    Args:
        model_input
        output: sampler output
        pinned_sampled_token_buffer: CPU-side pinned memory
                                     (receives copy of GPU-side token buffer.)
        sampled_token_ids: GPU-side token buffer
        logprobs_tensor: GPU-side tensor containing
                         logprobs computed during sampling
    """
    assert model_input.frozen_model_input is not None

    frozen_model_input = model_input.frozen_model_input
    assert frozen_model_input.sampling_metadata is not None
    sampling_metadata = frozen_model_input.sampling_metadata
    # samples generation should have been skipped
    assert not output.outputs

    pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]

    # We guarantee output tensors are ready, so it is safe to
    # pythonize the sampler output & obtain CPU-side logprobs.
    #
    # However we should check whether logprobs pythonization may
    # be skipped entirely, i.e. because no logprobs were requested
    # or pythonization was not deferred. To that end,
    #
    # * `prompt_logprobs_are_requested_for_prefill` signals that
    #   there are *any* prefill-phase requests which specify that
    #   prompt logprobs should be returned.
    #
    # * `any_logprobs_are_requested` signals that there are any
    #   requests which (1) specify that sample logprobs should be
    #   returned, or (2) are in the prefill phase AND specify that
    #   prompt logprobs should be returned.
    #
    # Later on, these flags cause adjustments to the pythonization
    # process to accommodate logprobs.

    seq_groups = sampling_metadata.seq_groups
    prompt_logprobs_are_requested_for_prefill = any([
        sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
        for sg in seq_groups
    ])
    any_logprobs_are_requested = (
        prompt_logprobs_are_requested_for_prefill
        or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))

    if prompt_logprobs_are_requested_for_prefill:
        # CPU GPU sync, after gathering *only* sampled tokens (since
        # requesting prompt logprobs leads `sampled_token_ids` to
        # include prompt token ids in addition to sampled token ids.)
        sample_idx_tensor = torch.tensor(
            [sdx for sg in seq_groups for sdx in sg.sample_indices])
        pinned_buffer = pinned_buffer.copy_(
            sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
    else:
        # CPU GPU sync
        pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
                                            non_blocking=False)

    # this will not block as the tensors are already on CPU
    samples_list = pinned_buffer.tolist()

    skip_sampler_cpu_output = (
        frozen_model_input.sampling_metadata.skip_sampler_cpu_output)

    # *Don't* skip logprobs pythonization *if*:
    # * Any requests require logprobs to be returned in this
    #   iteration AND
    # * These requests are being scheduled in a fashion which
    #   defers pythonization (i.e. multi-step scheduling.)
    do_pythonize_logprobs = (skip_sampler_cpu_output
                             and any_logprobs_are_requested)
    (
        prompt_logprobs,
        sample_logprobs,
    ) = (deferred_pythonize_logprobs(output, sampling_metadata,
                                     logprobs_tensor)
         if do_pythonize_logprobs else (None, None))

    for sgdx, (seq_group,
               sample_result) in enumerate(zip(seq_groups, samples_list)):
        if seq_group.sampling_params.logits_processors:
            assert len(seq_group.sampling_params.logits_processors) == 0, (
                "Logits Processors are not supported in multi-step decoding")

        if do_pythonize_logprobs:
            assert prompt_logprobs is not None
            assert sample_logprobs is not None

            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (  # Utilize deferred pythonization results
                prompt_logprobs[sgdx],
                sample_logprobs[sgdx],
            )
        elif any_logprobs_are_requested:
            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (
                # profile_run: use already-computed logprobs
                output.outputs[sgdx].prompt_logprobs,
                [sample.logprobs for sample in output.outputs[sgdx].samples])

        seq_ids = seq_group.seq_ids
        next_token_ids = sample_result
        parent_ids = [0]

        if cache is not None:
            completion_seq_group_output: CompletionSequenceGroupOutput = \
                cache.cached_completion_seq_group_output.get_object()
            completion_seq_group_output.samples.clear()
            seq_outputs: List[
                SequenceOutput] = completion_seq_group_output.samples
        else:
            seq_outputs = []

        for tdx, (parent_id,
                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
            if cache is not None:
                seq_output: SequenceOutput = cache.cached_seq_output.get_object(
                )
                seq_output.parent_seq_id = seq_ids[parent_id]
                seq_output.output_token = next_token_id

                if any_logprobs_are_requested:
                    seq_output.logprobs = group_sample_logprobs[tdx]
                else:
                    logprobs = next(iter(seq_output.logprobs.values()))
                    seq_output.logprobs.clear()

                    logprobs.logprob = float('inf')
                    logprobs.rank = None
                    logprobs.decoded_token = None

                    seq_output.logprobs[next_token_id] = logprobs

                seq_outputs.append(seq_output)
            else:
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   (group_sample_logprobs[tdx]
                                    if any_logprobs_are_requested else {
                                        next_token_id:
                                        Logprob(logprob=float('inf'),
                                                rank=None,
                                                decoded_token=None)
                                    })))

        if cache is not None:
            completion_seq_group_output.prompt_logprobs = \
                group_prompt_logprobs if any_logprobs_are_requested else None
            output.outputs.append(completion_seq_group_output)
        else:
            output.outputs.append(
                CompletionSequenceGroupOutput(
                    seq_outputs, (group_prompt_logprobs
                                  if any_logprobs_are_requested else None)))

    assert len(output.outputs) > 0