# multi_step_model_runner.py
  1. import dataclasses
  2. import functools
  3. from dataclasses import dataclass, field
  4. from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
  5. Union)
  6. try:
  7. from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
  8. except ModuleNotFoundError:
  9. # aphrodite_flash_attn is not installed, use the identical ROCm FA metadata
  10. from aphrodite.attention.backends.rocm_flash_attn import (
  11. ROCmFlashAttentionMetadata as FlashAttentionMetadata,
  12. )
  13. import torch
  14. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  15. IntermediateTensors, Logprob,
  16. SequenceGroupMetadata, SequenceOutput)
  17. from aphrodite.common.utils import PyObjectCache
  18. from aphrodite.distributed import get_pp_group
  19. from aphrodite.modeling.layers.sampler import (PromptLogprobs, SampleLogprobs,
  20. SamplerOutput, SamplingMetadata,
  21. get_logprobs,
  22. get_pythonized_sample_results)
  23. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  24. from aphrodite.worker.model_runner import (
  25. GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata)
  26. from aphrodite.worker.model_runner_base import (
  27. BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
  28. _init_frozen_model_input_from_tensor_dict,
  29. _init_sampling_metadata_from_tensor_dict)
  30. if TYPE_CHECKING:
  31. from aphrodite.attention.backends.abstract import AttentionBackend
  32. def seq_output_builder():
  33. return SequenceOutput(
  34. 0, 0,
  35. {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)})
  36. def completion_seq_group_output_builder():
  37. return CompletionSequenceGroupOutput([], None)
  38. # Used by pythonization to reduce python object allocations
  39. class PythonizationCache:
  40. def __init__(self):
  41. self.cached_seq_output = PyObjectCache(seq_output_builder)
  42. self.cached_completion_seq_group_output = PyObjectCache(
  43. completion_seq_group_output_builder)
  44. def reset(self):
  45. self.cached_seq_output.reset()
  46. self.cached_completion_seq_group_output.reset()
@dataclass
class ModelOutput:
    """The output of a single model forward pass.

    The sampler_output_ready_event is set when the tensors in
    sampler_output are ready (the model+sampler forward pass has
    completed). We use the event to synchronize the GPU->CPU transfer,
    which we want to only run when the data has been written to the
    GPU tensors. Until the event is ready, the tensors in sampler_output
    will have garbage data.

    There are two scenarios:
    1. The output tensors are ready and we can pythonize them immediately.
    2. The output tensors are not ready and we need to wait for the event to be
    ready.
    """
    sampler_output: SamplerOutput
    sampler_output_ready_event: torch.cuda.Event
    # GPU-side tensor of token ids sampled for this step.
    sampled_token_ids: Optional[torch.Tensor] = None
    # True once the GPU tensors have been converted into Python objects.
    pythonized: bool = False
    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None
    # Optional object pool used to cut per-token Python allocations.
    pythonization_cache: Optional[PythonizationCache] = None

    def pythonize(
        self,
        input_metadata: "StatefulModelInput",
        copy_stream: torch.cuda.Stream,
        pinned_sampled_token_buffer: torch.Tensor,
    ) -> None:
        """Pythonize the output. Blocking."""
        if not self.pythonized:
            self._pythonize_sampler_output(
                input_metadata, copy_stream, pinned_sampled_token_buffer, True
            )
            self.pythonized = True

    def maybe_pythonize(
        self,
        input_metadata: "StatefulModelInput",
        copy_stream: torch.cuda.Stream,
        pinned_sampled_token_buffer: torch.Tensor,
    ) -> None:
        """Pythonize the output if ready, else return None. Non-blocking."""
        if not self.pythonized:
            # _pythonize_sampler_output() returns False when the GPU event has
            # not fired yet, leaving self.pythonized unset for a later retry.
            self.pythonized = self._pythonize_sampler_output(
                input_metadata, copy_stream, pinned_sampled_token_buffer, False
            )

    def _pythonize_sampler_output(
        self,
        input_metadata: "StatefulModelInput",
        copy_stream: torch.cuda.Stream,
        pinned_sampled_token_buffer: torch.Tensor,
        blocking: bool,
    ) -> bool:
        """
        If blocking is set, will block until the forward pass for the output is
        ready and pythonize the output. Upon completing Pythonization, erases
        self.logprobs (note that a non-blocking call that is performed when
        the sampler output is not yet ready, will not erase self.logprobs.)

        Returns True iff the output was pythonized by this call.
        """
        assert self.sampled_token_ids is not None
        if not blocking and not self.sampler_output_ready_event.query():
            # Non-blocking mode and the forward pass has not finished writing
            # the output tensors yet: report "not pythonized".
            return False

        if blocking:
            self.sampler_output_ready_event.synchronize()
        # Run the GPU->CPU transfer on a dedicated stream so it can overlap
        # with compute enqueued on the current stream.
        with torch.cuda.stream(copy_stream):
            _pythonize_sampler_output(
                input_metadata,
                self.sampler_output,
                pinned_sampled_token_buffer,
                self.sampled_token_ids, self.logprobs,
                self.pythonization_cache)

        # Erase the logprobs GPU-side tensor.
        # Note that although _pythonize_sampler_output() runs in its
        # own CUDA stream, nonetheless _pythonize_sampler_output()
        # cannot return until Pythonization is complete; therefore
        # we know that by the time the CPU reaches this point,
        # `self.logprobs` is no longer needed.
        self.logprobs = None
        return True
  124. @dataclass(frozen=False)
  125. class StatefulModelInput(BroadcastableModelInput):
  126. # actual frozen model input dataclass passed to _base_model_runner
  127. frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None
  128. # list of model outputs for each step, may not be all pythonized
  129. cached_outputs: List[ModelOutput] = field(default_factory=list)
  130. # used to pass sampled token ids from the last step to the current step for
  131. # TP workers. Used to append to end of outputs and used by advance_step
  132. last_sampled_token_ids: Optional[torch.Tensor] = None
  133. current_step: int = 0
  134. is_multi_step: bool = True
  135. is_last_step: bool = False
  136. is_first_multi_step: bool = False
  137. # ping-pong data structures for multi-step to wait on the previous step
  138. step_cuda_events: List[torch.cuda.Event] = field(
  139. default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2
  140. )
  141. num_seqs: int = -1
  142. num_queries: int = -1
  143. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  144. assert self.frozen_model_input is not None
  145. tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
  146. new_tensor_dict = {
  147. "last_sampled_token_ids": self.last_sampled_token_ids,
  148. "current_step": self.current_step,
  149. "is_multi_step": self.is_multi_step,
  150. "is_last_step": self.is_last_step,
  151. "is_first_multi_step": self.is_first_multi_step,
  152. "num_seqs": self.num_seqs,
  153. "num_queries": self.num_queries,
  154. }
  155. tensor_dict.update(new_tensor_dict)
  156. return tensor_dict
  157. @classmethod
  158. def from_broadcasted_tensor_dict(
  159. cls,
  160. tensor_dict: Dict[str, Any],
  161. attn_backend: Optional["AttentionBackend"] = None,
  162. ) -> "StatefulModelInput":
  163. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  164. if attn_backend is not None:
  165. tensor_dict = _init_attn_metadata_from_tensor_dict(
  166. attn_backend, tensor_dict
  167. )
  168. tensor_dict = _init_frozen_model_input_from_tensor_dict(
  169. ModelInputForGPUWithSamplingMetadata, tensor_dict
  170. )
  171. return cls(**tensor_dict)
  172. def record_step_event(self, current_stream: torch.cuda.Stream):
  173. # record the event for the current step so that the next step can sync
  174. # on it. We modulo by 2 to keep the events in a circular buffer and
  175. # support any attn backends that may be supported in the future. ie
  176. # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU.
  177. self.step_cuda_events[self.current_step & 1] = torch.cuda.Event(
  178. blocking=True
  179. )
  180. self.step_cuda_events[self.current_step & 1].record(current_stream)
  181. def wait_previous_step(self):
  182. # These cuda events are an explicit synchronization to ensure that
  183. # advance_step() (for other attn backends that may be supported in the
  184. # future) do not clobber any data structures that is also used by any
  185. # enqueued forwards steps. For distributed case, only a single event is
  186. # needed, but for single GPU case, since we can let the CPU run much
  187. # further ahead, two events allow us to overlap the advance_step with
  188. # the previous forward (ie using two DecodeWrappers for flashinfer
  189. # backend)
  190. self.step_cuda_events[(self.current_step + 1) & 1].wait()
  191. def add_sampler_output(
  192. self,
  193. sampler_output: SamplerOutput,
  194. sampled_token_ids: Optional[torch.Tensor] = None,
  195. ):
  196. self.cached_outputs.append(
  197. ModelOutput(
  198. sampler_output=sampler_output,
  199. sampler_output_ready_event=None,
  200. sampled_token_ids=sampled_token_ids,
  201. pythonized=False,
  202. )
  203. )
# MutableModelInputForGPUWithMultiStepMetadata is not subclass of
# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step
# metadata
# mypy: disable-error-code=type-var
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
    # mypy: enable-error-code=type-var
    """Runner that drives a wrapped single-step runner for multiple decode
    steps, deferring GPU->CPU transfers and pythonization until needed."""

    def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # uses the base model runner to execute the model and wraps it with
        # multi-step logic
        self._base_model_runner: GPUModelRunnerBase = base_model_runner

        self.is_multi_step = self.scheduler_config.is_multi_step
        # used to copy tensors from GPU to CPU asynchronously
        self._copy_stream = torch.cuda.Stream()
        # Lazily-allocated pinned CPU buffer for sampled token ids
        # (created on first driver-worker/last-rank execute_model call).
        self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

        self.pythonization_cache = PythonizationCache()

    def make_model_input_from_broadcasted_tensor_dict(
        self, tensor_dict: Dict[str, Any]
    ) -> StatefulModelInput:
        """Rebuild a StatefulModelInput broadcast from the driver worker."""
        model_input = StatefulModelInput.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> StatefulModelInput:
        """Prepare the frozen base-runner input and wrap it with
        multi-step state (seq/query counts, empty output cache)."""
        frozen_model_input = self._base_model_runner.prepare_model_input(
            seq_group_metadata_list, virtual_engine, finished_requests_ids
        )

        model_input = StatefulModelInput(
            frozen_model_input=frozen_model_input,
            num_seqs=len(frozen_model_input.seq_lens),
            num_queries=len(frozen_model_input.query_lens),
        )
        return model_input

    def _async_process_outputs(self, model_input: StatefulModelInput,
                               output_proc_callback: Callable):
        """Pythonize cached outputs in order and feed them to the async
        output-processing callback; stop at the first not-ready output."""
        # Proceed with pythonization and output_proc in order.
        # Stop on the first one that fails to pythonize
        output_proc_callback()

        cont = True
        for model_output in model_input.cached_outputs:
            if not model_output.pythonized:
                model_output.maybe_pythonize(model_input, self._copy_stream,
                                             self.pinned_sampled_token_ids)
                if model_output.pythonized:
                    ctx = output_proc_callback.keywords["ctx"]
                    ctx.append_output(
                        outputs=[model_output.sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False)

                    output_proc_callback()
                else:
                    cont = False

            if not cont:
                break

    def _final_process_outputs(self, model_input: StatefulModelInput,
                               output_proc_callback: Optional[Callable]):
        """Blockingly pythonize all remaining cached outputs at the last
        step and return the sampler outputs to hand back to the engine."""
        assert model_input.frozen_model_input is not None

        has_async_callback = output_proc_callback is not None

        outputs = []
        for output_id in range(len(model_input.cached_outputs)):
            output = model_input.cached_outputs[output_id]
            is_last_step = output_id == len(model_input.cached_outputs) - 1

            # For non-async case:
            #   -- We simply add the outputs
            # For async case:
            #   -- Invoke callback, pythonize, add to callback queue and repeat
            #   -- For last output, just add to callback queue
            if has_async_callback:
                assert output_proc_callback is not None

                # Invoke callback before pythonize (to overlap with GPU)
                output_proc_callback()

                # Pythonize
                if not output.pythonized:
                    output.pythonize(model_input, self._copy_stream,
                                     self.pinned_sampled_token_ids)

                # For non last step, add to callback queue to chain
                # callbacks=>pythonize pairs (for GPU overlap)
                if not is_last_step:
                    ctx = output_proc_callback.keywords[  # type: ignore
                        "ctx"]  # type: ignore
                    ctx.append_output(
                        outputs=[output.sampler_output],
                        seq_group_metadata_list=ctx.
                        seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False)
                else:
                    outputs.append(output.sampler_output)
            else:
                output.pythonize(model_input, self._copy_stream,
                                 self.pinned_sampled_token_ids)
                outputs.append(output.sampler_output)

        return outputs

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: StatefulModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """
        Execute the model for a single step and update multi-step
        metadata
        """
        assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None

        # path for warm up runs
        if not model_input.is_multi_step:
            return self._base_model_runner.execute_model(
                frozen_model_input, kv_caches, intermediate_tensors, num_steps
            )

        # make sure we skip the sampler on the last rank and only pythonize
        # if CPU is ahead.
        if self.is_driver_worker and get_pp_group().is_last_rank:
            if self.pinned_sampled_token_ids is None:
                self.pinned_sampled_token_ids = torch.zeros(
                    (self.scheduler_config.max_num_seqs, 1),
                    dtype=torch.long,
                    device="cpu",
                    pin_memory=True,
                )

            self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
                True
            )
            if frozen_model_input.sampling_metadata:
                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
                    True
                )

        # some pre-execute model logic for multi-step:
        #   - if it's the first step, we need to reset the sampling tensors
        #   - if it's not the first step, we need to advance the step using the
        #     appended sampler output from last iteration
        #   - also maybe pythonize if CPU is ahead of GPU
        current_stream = torch.cuda.current_stream()
        if not model_input.is_first_multi_step:
            # Explicitly block on the previous step's forward to make sure we
            # don't clobber any GPU tensors still in use.
            # This is not needed for flashattn backend, but for other attn
            # backends such as flashinfer that performs extra CPU operations on
            # input metadata we may need to synchronize any CPU operations that
            # might clobber enqueued forwards. (prevents CPU from running too
            # far ahead if needed)
            model_input.wait_previous_step()
            model_input = self._advance_step(
                model_input, model_input.cached_outputs[-1].sampler_output
            )

        output_proc_callback = None
        if frozen_model_input.async_callback is not None:
            output_proc_callback = frozen_model_input.async_callback
            assert output_proc_callback is not None
            # Chain our multi-step-aware async processor in front of the
            # engine-supplied callback.
            async_callback = functools.partial(
                self._async_process_outputs,
                model_input=model_input,
                output_proc_callback=output_proc_callback)

            frozen_model_input = dataclasses.replace(  # type: ignore
                model_input.frozen_model_input,
                async_callback=async_callback)
            assert frozen_model_input is not None

        # Execute the model
        output = self._base_model_runner.execute_model(
            frozen_model_input, kv_caches, intermediate_tensors, num_steps=1
        )

        # record the event for the current step so that the next step can sync
        model_input.record_step_event(current_stream)

        if get_pp_group().is_last_rank and self.is_driver_worker:
            assert (
                len(output) == 1
            ), "MultiStepModelRunner requires single-step base_models"

            # event for the pythonization so that we only pythonize if the
            # tensors are ready. May be able to be combined with the step event
            output_ready_event = torch.cuda.Event()
            output_ready_event.record(current_stream)
            if self.parallel_config.pipeline_parallel_size > 1:
                # PP>1: sampled ids must travel to the next rank via CPU.
                output[0].sampled_token_ids_cpu = output[
                    0
                ].sampled_token_ids.cpu()
            model_input.cached_outputs.append(
                ModelOutput(
                    output[0],
                    output_ready_event,
                    output[0].sampled_token_ids, False,
                    output[0].logprobs, self.pythonization_cache))

            # These GPU tensors are not required by multi-step;
            # erase them to ensure they are not pythonized or
            # transferred to CPU
            output[0].sampled_token_ids = None
            output[0].sampled_token_probs = None
            output[0].logprobs = None

            # Pythonize the output if CPU is ahead and the previous step is
            # ready.
            if frozen_model_input.async_callback is None:
                for model_output in model_input.cached_outputs:
                    model_output.maybe_pythonize(model_input,
                                                 self._copy_stream,
                                                 self.pinned_sampled_token_ids)

        model_input.current_step += 1

        if not get_pp_group().is_last_rank:
            # Should be IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            return output
        if not self.is_driver_worker:
            return []

        # Pythonize the output and block if needed since it is the last step
        if model_input.is_last_step:
            outputs = self._final_process_outputs(model_input,
                                                  output_proc_callback)
            self.pythonization_cache.reset()
            return outputs

        # should be [SamplerOutput]
        return output

    def _update_sampling_metadata(
        self, sampling_metadata, num_seqs, num_queries
    ):
        """Sanity-check that the sampling metadata describes a pure-decode
        batch (no prompts, one sample index per query)."""
        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (num_queries,)
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _advance_step(
        self, model_input: StatefulModelInput, out: SamplerOutput
    ) -> StatefulModelInput:
        """Advance attention metadata in-place to the next decode step using
        the previous step's sampled token ids."""
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None
        assert frozen_model_input.attn_metadata is not None

        num_seqs = model_input.num_seqs
        num_queries = model_input.num_queries
        assert num_seqs > 0
        assert num_queries > 0
        assert num_seqs >= num_queries

        attn_metadata = frozen_model_input.attn_metadata
        # Only the (ROCm-)FlashAttention metadata implements advance_step.
        assert isinstance(attn_metadata, FlashAttentionMetadata)
        attn_metadata.advance_step(
            frozen_model_input,
            model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
            num_seqs, num_queries)

        if frozen_model_input.seq_lens is not None:
            # Mirror the advanced sequence lengths back onto the frozen input.
            for i in range(num_queries):
                frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]

        return model_input

    # The remaining methods simply delegate to the wrapped single-step runner.

    def load_model(self) -> None:
        return self._base_model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        return self._base_model_runner.save_sharded_state(
            path, pattern, max_size
        )

    def save_tensorized_model(
        self, tensorizer_config: TensorizerConfig
    ) -> None:
        return self._base_model_runner.save_tensorized_model(tensorizer_config)

    def profile_run(self) -> None:
        return self._base_model_runner.profile_run()

    def remove_all_loras(self):
        return self._base_model_runner.remove_all_loras()

    def capture_model(self, kv_caches: List[List]) -> None:
        return self._base_model_runner.capture_model(kv_caches)

    @property
    def vocab_size(self) -> int:
        return self._base_model_runner.vocab_size
  485. DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
  486. Optional[List[SampleLogprobs]]]
  487. def deferred_pythonize_logprobs(
  488. output: SamplerOutput,
  489. sampling_metadata: SamplingMetadata,
  490. logprobs_tensor: Optional[torch.Tensor],
  491. ) -> DeferredLogprobsReturnType:
  492. """Perform deferred logprob Pythonization.
  493. 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
  494. 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
  495. utilizing the Pythonized sampler result computed in step 1.
  496. These deferred computations are not required for single-step scheduling
  497. or the `profile_run()` phase of multi-step scheduling.
  498. Args:
  499. output: sampler output (under deferred Pythonization)
  500. sampling_metadata
  501. Returns:
  502. prompt_logprobs (CPU), sample_logprobs (CPU)
  503. """
  504. # - Deferred pythonization of sample result
  505. sampler_result = get_pythonized_sample_results(
  506. output.deferred_sample_results_args)
  507. # - Erase the GPU-side deferred sample_result
  508. # computation args to ensure it is never
  509. # pythonized or transferred to CPU
  510. output.deferred_sample_results_args = None
  511. # - Deferred pythonization of logprobs
  512. (
  513. prompt_logprobs,
  514. sample_logprobs,
  515. ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
  516. assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
  517. assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
  518. return prompt_logprobs, sample_logprobs
def _pythonize_sampler_output(
    model_input: StatefulModelInput,
    output: SamplerOutput,
    pinned_sampled_token_buffer: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    logprobs_tensor: Optional[torch.Tensor],
    cache: Optional[PythonizationCache],
) -> None:
    """ This function is only called when the output tensors are ready.
    See :class:`ModelOutput`.

    Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
    adding a Pythonized output data structure
    (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.

    Args:
      model_input
      output: sampler output
      pinned_sampled_token_token_buffer: CPU-side pinned memory
        (receives copy of
        GPU-side token buffer.)
      sampled_token_ids: GPU-side token buffer
      logprobs_tensor: GPU-side tensor containing
        logprobs computed during sampling
      cache: optional object pool to recycle output objects
    """
    assert model_input.frozen_model_input is not None

    frozen_model_input = model_input.frozen_model_input
    assert frozen_model_input.sampling_metadata is not None
    # samples generation should have been skipped
    assert not output.outputs

    pinned_buffer = pinned_sampled_token_buffer[: model_input.num_queries]

    # CPU GPU sync
    pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)

    # this will not block as the tensors are already on CPU
    samples_list = pinned_buffer.tolist()

    sampling_metadata = frozen_model_input.sampling_metadata

    skip_sampler_cpu_output = (
        frozen_model_input.sampling_metadata.skip_sampler_cpu_output)

    # We are guaranteed output tensors are ready, so it is safe to
    # pythonize the sampler output & obtain CPU-side logprobs.
    #
    # However this computation may be skipped entirely
    # if no pythonization was deferred.
    seq_groups = sampling_metadata.seq_groups
    logprobs_are_requested = any([
        sg.sampling_params.logprobs is not None
        or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
    ])
    do_pythonize_logprobs = (skip_sampler_cpu_output
                             and logprobs_are_requested)
    (
        prompt_logprobs,
        sample_logprobs,
    ) = (deferred_pythonize_logprobs(output, sampling_metadata,
                                     logprobs_tensor)
         if do_pythonize_logprobs else (None, None))

    for sgdx, (seq_group,
               sample_result) in enumerate(zip(seq_groups, samples_list)):
        if seq_group.sampling_params.logits_processors:
            assert len(seq_group.sampling_params.logits_processors) == 0, (
                "Logits Processors are not supported in multi-step decoding")

        if do_pythonize_logprobs:
            assert prompt_logprobs is not None
            assert sample_logprobs is not None

            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (  # Utilize deferred pythonization results
                prompt_logprobs[sgdx],
                sample_logprobs[sgdx],
            )
        elif logprobs_are_requested:
            # NOTE(review): this branch indexes output.outputs, which the
            # assert at the top requires to be empty — presumably it is only
            # reachable when skip_sampler_cpu_output is False and the sampler
            # already populated outputs (e.g. profile_run); confirm against
            # callers before relying on it.
            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (
                # profile_run: use already-computed logprobs
                output.outputs[sgdx].prompt_logprobs,
                [sample.logprobs for sample in output.outputs[sgdx].samples])

        seq_ids = seq_group.seq_ids
        next_token_ids = sample_result
        # Beam search is not supported here, so there is a single parent.
        parent_ids = [0]

        if cache is not None:
            # Recycle a pooled CompletionSequenceGroupOutput; its samples
            # list is cleared and refilled below.
            completion_seq_group_output: CompletionSequenceGroupOutput = \
                cache.cached_completion_seq_group_output.get_object()
            completion_seq_group_output.samples.clear()
            seq_outputs: List[
                SequenceOutput] = completion_seq_group_output.samples
        else:
            seq_outputs = []

        for tdx, (parent_id,
                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
            if cache is not None:
                # Recycle a pooled SequenceOutput and overwrite its fields.
                seq_output: SequenceOutput = cache.cached_seq_output.get_object(
                )
                seq_output.parent_seq_id = seq_ids[parent_id]
                seq_output.output_token = next_token_id

                if logprobs_are_requested:
                    seq_output.logprobs = group_sample_logprobs[tdx]
                else:
                    # Reuse the pooled object's existing Logprob instance,
                    # resetting it to the placeholder (inf) values.
                    logprobs = next(iter(seq_output.logprobs.values()))
                    seq_output.logprobs.clear()

                    logprobs.logprob = float('inf')
                    logprobs.rank = None
                    logprobs.decoded_token = None

                    seq_output.logprobs[next_token_id] = logprobs

                seq_outputs.append(seq_output)

            else:
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   (group_sample_logprobs[tdx]
                                    if logprobs_are_requested else {
                                        next_token_id:
                                        Logprob(logprob=float('inf'),
                                                rank=None,
                                                decoded_token=None)
                                    })))
        if cache is not None:
            completion_seq_group_output.prompt_logprobs = \
                group_prompt_logprobs if logprobs_are_requested else None
            output.outputs.append(completion_seq_group_output)
        else:
            output.outputs.append(
                CompletionSequenceGroupOutput(
                    seq_outputs, (group_prompt_logprobs
                                  if logprobs_are_requested else None)))

    assert len(output.outputs) > 0