util.py

import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       SamplerOutput, SequenceGroupMetadata,
                                       SequenceOutput)

SeqId = int


def get_all_num_logprobs(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Given a list of SequenceGroupMetadata, create a list of all num_logprobs.

    If the sampling params do not call for any logprobs, return 0 for that
    sequence.
    """

    all_num_logprobs: List[int] = []
    for seq_group_metadata in seq_group_metadata_list:
        num_logprobs = seq_group_metadata.sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        all_num_logprobs.append(num_logprobs)

    return all_num_logprobs
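
# Usage sketch (illustrative, not part of the original module): the helper
# only reads .sampling_params.logprobs, so a SimpleNamespace stand-in for
# SequenceGroupMetadata suffices at runtime:
#
#     from types import SimpleNamespace
#     metas = [
#         SimpleNamespace(sampling_params=SimpleNamespace(logprobs=5)),
#         SimpleNamespace(sampling_params=SimpleNamespace(logprobs=None)),
#     ]
#     assert get_all_num_logprobs(metas) == [5, 0]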


def get_sampled_token_logprobs(
        # shape [num_steps, batch_size, vocab_size]
        logprob_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
    """
    num_steps, batch_size, vocab_size = logprob_tensor.shape

    selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
                                       torch.arange(batch_size),
                                       sampled_token_ids, ]
    expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
        -1, -1, vocab_size)
    sampled_token_ids_ranks = (logprob_tensor >=
                               expanded_selected_logprobs).sum(-1)

    return sampled_token_ids_ranks, selected_logprobs
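
# Usage sketch (illustrative shapes): the returned ranks are 1-based, so the
# argmax token at a position always has rank 1 (barring exact ties):
#
#     logprobs = torch.log_softmax(torch.randn(2, 3, 8), dim=-1)
#     ids = logprobs.argmax(dim=-1)  # shape [num_steps, batch_size]
#     ranks, picked = get_sampled_token_logprobs(logprobs, ids)
#     assert (ranks == 1).all()
#     assert picked.shape == (2, 3)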


def create_sequence_group_output(
        token_id: int,
        token_id_logprob_rank: int,
        token_id_logprob: float,
        seq_id: SeqId,
        topk_token_ids: List[Optional[int]],
        topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput:
    """Create a SequenceGroupOutput given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
    """
    # Aphrodite logprobs always include the sampled token. In addition, the
    # user may request top-k logprobs (where k varies per user up to
    # max_logprobs).
    logprobs: Dict[Optional[int], Logprob] = {
        token_id: Logprob(
            logprob=token_id_logprob,
            rank=token_id_logprob_rank,
        ),
    }
    logprobs.update({
        topk_token_ids[topk_logprob_index]: Logprob(
            logprob=topk_logprobs[topk_logprob_index],
            rank=topk_logprob_index + 1,
        )
        for topk_logprob_index, _ in enumerate(topk_token_ids)
    })

    return CompletionSequenceGroupOutput(
        samples=[
            SequenceOutput(parent_seq_id=seq_id,
                           output_token=token_id,
                           logprobs=logprobs)
        ],
        # TODO: add prompt logprobs support.
        prompt_logprobs=None,
    )
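
# Usage sketch (illustrative values): one sequence that sampled token 42 at
# rank 1, with two top-k entries requested:
#
#     out = create_sequence_group_output(token_id=42,
#                                        token_id_logprob_rank=1,
#                                        token_id_logprob=-0.25,
#                                        seq_id=0,
#                                        topk_token_ids=[42, 7],
#                                        topk_logprobs=[-0.25, -1.9])
#     assert out.samples[0].output_token == 42
#     assert out.samples[0].logprobs[7].rank == 2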


def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
    """Utility function that splits a batch based on whether the proposal len
    is zero or not. We should remove this once Aphrodite supports per-sequence
    proposal lens in a batch.
    """

    if select_proposal_len_zero:
        predicate = lambda proposal_len: proposal_len == 0
    else:
        predicate = lambda proposal_len: proposal_len != 0

    indices = [
        i for i, (_, proposal_len
                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
        if predicate(proposal_len)
    ]
    seq_groups = [
        seq_group for seq_group, proposal_len in zip(
            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
    ]

    return seq_groups, indices
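
# Usage sketch (illustrative): only zip/enumerate are applied to the metadata,
# so plain strings can stand in for SequenceGroupMetadata here:
#
#     groups, indices = split_batch_by_proposal_len(
#         ["a", "b", "c"], [0, 3, 0], select_proposal_len_zero=True)
#     assert groups == ["a", "c"] and indices == [0, 2]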


def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Utility function which converts a list of SamplerOutput to tensors.

    sampler_transposed indicates whether we need to do an additional tensor
    transpose; the shapes below assume it is True (otherwise the first two
    dimensions are swapped).

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]

        sampled_token_logprobs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]

        sampled_hidden_states: Optional[torch.Tensor]
            shape: [batch_size, len(sampler_output_list), hidden_dim]
    """

    # shape: [num_sampler_output, batch_size, vocab_size]
    sampled_token_probs = torch.stack(
        [
            sampler_output.sampled_token_probs
            for sampler_output in sampler_output_list
        ],
        dim=0,
    )

    if sampler_transposed:
        # shape: [batch_size, num_sampler_output, vocab_size]
        sampled_token_probs = sampled_token_probs.transpose(0, 1)

    # shape: [num_sampler_output, batch_size, vocab_size]
    sampled_token_logprobs = torch.stack(
        [sampler_output.logprobs for sampler_output in sampler_output_list],
        dim=0,
    )

    if sampler_transposed:
        # shape: [batch_size, num_sampler_output, vocab_size]
        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)

    # shape: [num_sampler_output, batch_size]
    sampled_token_ids = torch.stack(
        [
            sampler_output.sampled_token_ids.flatten()
            for sampler_output in sampler_output_list
        ],
        dim=0,
    )

    if sampler_transposed:
        # shape: [batch_size, num_sampler_output]
        sampled_token_ids = sampled_token_ids.transpose(0, 1)

    if sampler_output_list[0].hidden_states is not None:
        # shape: [num_sampler_output, batch_size, hidden_dim]
        sampled_hidden_states = torch.stack(
            [
                sampler_output.hidden_states
                for sampler_output in sampler_output_list
            ],
            dim=0,
        )
        if sampler_transposed:
            # shape: [batch_size, num_sampler_output, hidden_dim]
            sampled_hidden_states = sampled_hidden_states.transpose(0, 1)
    else:
        sampled_hidden_states = None

    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
            sampled_hidden_states)
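
# Usage sketch (illustrative; assumes SamplerOutput exposes these fields as
# constructor keywords, mirroring how they are read above):
#
#     k, bs, vocab = 4, 2, 8  # proposal steps, batch size, vocab size
#     outs = [
#         SamplerOutput(outputs=[],
#                       sampled_token_probs=torch.rand(bs, vocab),
#                       logprobs=torch.rand(bs, vocab),
#                       sampled_token_ids=torch.randint(vocab, (bs, 1)))
#         for _ in range(k)
#     ]
#     ids, probs, logprobs, hidden = sampler_output_to_torch(
#         outs, sampler_transposed=True)
#     assert ids.shape == (bs, k) and probs.shape == (bs, k, vocab)
#     assert hidden is None  # no hidden_states attached above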


def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
                              vocab_size: int, device: str) -> None:
    """Helper method which mocks out the GPU tensors in SamplerOutput with
    dummy values.
    """
    values = [
        sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
    ]
    assert all(v is None for v in values) or not any(v is None for v in values)
    if not any(v is None for v in values):
        # Do nothing if the tensors are already created (usually in unit
        # tests).
        return

    # Softmax to ensure valid probs.
    sampler_output.sampled_token_probs = torch.nn.functional.softmax(
        torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
        dim=-1)

    sampler_output.sampled_token_ids = torch.randint(low=10,
                                                     high=100,
                                                     size=(batch_size, ),
                                                     dtype=torch.long,
                                                     device=device)
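
# Usage sketch (illustrative; assumes the SamplerOutput constructor as above):
#
#     out = SamplerOutput(outputs=[])  # probs/ids left as None
#     maybe_mock_device_tensors(out, batch_size=4, vocab_size=32, device="cpu")
#     assert out.sampled_token_probs.shape == (4, 32)
#     assert out.sampled_token_ids.shape == (4,)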


@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    If running with CUDA graphs, you must enable nsys CUDA graph profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()
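
# Usage sketch: wrap any region you want to see as an NVTX range in Nsight
# Systems; `step` and `run_decode_step` are hypothetical names:
#
#     with nvtx_range("decode step {}", step):
#         run_decode_step()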


class Timer:
    """Basic timer context manager for measuring CPU time."""

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time_s = self.end_time - self.start_time
        self.elapsed_time_ms = self.elapsed_time_s * 1000
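
# Usage sketch: time an arbitrary block on the wall clock.
#
#     with Timer() as t:
#         time.sleep(0.1)
#     print(f"{t.elapsed_time_ms:.1f} ms")  # ~100 ms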