outputs.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import time
  2. from dataclasses import dataclass
  3. from typing import List, Optional, Tuple, Union
  4. from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
  5. SampleLogprobs, SequenceGroup,
  6. SequenceStatus)
  7. from aphrodite.lora.request import LoRARequest
  8. @dataclass
  9. class CompletionOutput:
  10. """The output data of one completion output of a request.
  11. Args:
  12. index: The index of the output in the request.
  13. text: The generated output text.
  14. token_ids: The token IDs of the generated output text.
  15. cumulative_logprob: The cumulative log probability of the generated
  16. output text.
  17. logprobs: The log probabilities of the top probability words at each
  18. position if the logprobs are requested.
  19. finish_reason: The reason why the sequence is finished.
  20. stop_reason: The stop string or token id that caused the completion
  21. to stop, None if the completion finished for some other reason
  22. including encountering the EOS token.
  23. lora_request: The LoRA request that was used to generate the output.
  24. """
  25. index: int
  26. text: str
  27. token_ids: Tuple[int, ...]
  28. cumulative_logprob: Optional[float]
  29. logprobs: Optional[SampleLogprobs]
  30. finish_reason: Optional[str] = None
  31. stop_reason: Union[int, str, None] = None
  32. lora_request: Optional[LoRARequest] = None
  33. def finished(self) -> bool:
  34. return self.finish_reason is not None
  35. def __repr__(self) -> str:
  36. return (f"CompletionOutput(index={self.index}, "
  37. f"text={self.text!r}, "
  38. f"token_ids={self.token_ids}, "
  39. f"cumulative_logprob={self.cumulative_logprob}, "
  40. f"logprobs={self.logprobs}, "
  41. f"finish_reason={self.finish_reason}, "
  42. f"stop_reason={self.stop_reason})")
  43. @dataclass
  44. class EmbeddingOutput:
  45. """The output data of one completion output of a request.
  46. Args:
  47. embedding: The embedding vector, which is a list of floats. The
  48. length of vector depends on the model as listed in the embedding guide.
  49. """
  50. embedding: List[float]
  51. def __repr__(self) -> str:
  52. return (f"EmbeddingOutput("
  53. f"embedding={len(self.embedding)})")
  54. class RequestOutput:
  55. """The output data of a completion request to the LLM.
  56. Args:
  57. request_id: The unique ID of the request.
  58. prompt: The prompt string of the request.
  59. prompt_token_ids: The token IDs of the prompt.
  60. prompt_logprobs: The log probabilities to return per prompt token.
  61. outputs: The output sequences of the request.
  62. finished: Whether the whole request is finished.
  63. metrics: Metrics associated with the request.
  64. lora_request: The LoRA request that was used to generate the output.
  65. """
  66. def __init__(
  67. self,
  68. request_id: str,
  69. prompt: Optional[str],
  70. prompt_token_ids: List[int],
  71. prompt_logprobs: Optional[PromptLogprobs],
  72. outputs: List[CompletionOutput],
  73. finished: bool,
  74. metrics: Optional[RequestMetrics] = None,
  75. lora_request: Optional[LoRARequest] = None,
  76. ) -> None:
  77. self.request_id = request_id
  78. self.prompt = prompt
  79. self.prompt_token_ids = prompt_token_ids
  80. self.prompt_logprobs = prompt_logprobs
  81. self.outputs = outputs
  82. self.finished = finished
  83. self.metrics = metrics
  84. self.lora_request = lora_request
  85. @classmethod
  86. def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
  87. if seq_group.sampling_params is None:
  88. raise ValueError(
  89. "Sampling parameters are missing for a CompletionRequest.")
  90. # Get the top-n sequences.
  91. n = seq_group.sampling_params.n
  92. seqs = seq_group.get_seqs()
  93. if n == 1:
  94. top_n_seqs = seqs
  95. else:
  96. if seq_group.sampling_params.use_beam_search:
  97. sorting_key = lambda seq: seq.get_beam_search_score(
  98. seq_group.sampling_params.length_penalty)
  99. else:
  100. sorting_key = lambda seq: seq.get_cumulative_logprob()
  101. sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
  102. top_n_seqs = sorted_seqs[:n]
  103. # Create the outputs.
  104. # NOTE: We need omit logprobs here explicitly because the sequence
  105. # always has the logprobs of the sampled tokens even if the
  106. # logprobs are not requested.
  107. include_logprobs = seq_group.sampling_params.logprobs is not None
  108. text_buffer_length = seq_group.sampling_params.output_text_buffer_length
  109. outputs = [
  110. CompletionOutput(
  111. seqs.index(seq),
  112. seq.get_output_text_to_return(text_buffer_length),
  113. seq.get_output_token_ids(),
  114. seq.get_cumulative_logprob() if include_logprobs else None,
  115. seq.output_logprobs if include_logprobs else None,
  116. SequenceStatus.get_finished_reason(seq.status),
  117. seq.stop_reason) for seq in top_n_seqs
  118. ]
  119. # Every sequence in the sequence group should have the same prompt.
  120. prompt = seq_group.prompt
  121. prompt_token_ids = seq_group.prompt_token_ids
  122. prompt_logprobs = seq_group.prompt_logprobs
  123. finished = seq_group.is_finished()
  124. finished_time = time.time() if finished else None
  125. seq_group.set_finished_time(finished_time)
  126. return cls(
  127. seq_group.request_id,
  128. prompt,
  129. prompt_token_ids,
  130. prompt_logprobs,
  131. outputs,
  132. finished,
  133. seq_group.metrics,
  134. lora_request=seq_group.lora_request,
  135. )
  136. def __repr__(self) -> str:
  137. return (f"RequestOutput(request_id={self.request_id}, "
  138. f"prompt={self.prompt!r}, "
  139. f"prompt_token_ids={self.prompt_token_ids}, "
  140. f"prompt_logprobs={self.prompt_logprobs}, "
  141. f"outputs={self.outputs}, "
  142. f"finished={self.finished}, "
  143. f"metrics={self.metrics}, "
  144. f"lora_request={self.lora_request})")
  145. class EmbeddingRequestOutput:
  146. """
  147. The output data of an embedding request to the LLM.
  148. Args:
  149. request_id (str): A unique identifier for the embedding request.
  150. outputs (EmbeddingOutput): The embedding results for the given input.
  151. prompt_token_ids (List[int]): A list of token IDs used in the prompt.
  152. finished (bool): A flag indicating whether the embedding is completed.
  153. """
  154. def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
  155. prompt_token_ids: List[int], finished: bool):
  156. self.request_id = request_id
  157. self.prompt_token_ids = prompt_token_ids
  158. self.finished = finished
  159. self.outputs = outputs
  160. @classmethod
  161. def from_seq_group(cls,
  162. seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
  163. if seq_group.embeddings is None:
  164. raise ValueError(
  165. "Embeddings are missing in seq_group for EmbeddingRequest.")
  166. output = EmbeddingOutput(seq_group.embeddings)
  167. prompt_token_ids = seq_group.prompt_token_ids
  168. finished = seq_group.is_finished()
  169. return cls(seq_group.request_id, output, prompt_token_ids, finished)
  170. def __repr__(self):
  171. """
  172. Returns a string representation of an EmbeddingRequestOutput instance.
  173. The representation includes the request_id and the number of outputs,
  174. providing a quick overview of the embedding request's results.
  175. Returns:
  176. str: A string representation of the EmbeddingRequestOutput instance.
  177. """
  178. return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
  179. f"outputs={repr(self.outputs)}, "
  180. f"prompt_token_ids={self.prompt_token_ids}, "
  181. f"finished={self.finished})")
  182. class RequestOutputFactory:
  183. @staticmethod
  184. def create(seq_group):
  185. # Determine the type based on a condition, for example:
  186. if hasattr(seq_group,
  187. 'embeddings') and seq_group.embeddings is not None:
  188. return EmbeddingRequestOutput.from_seq_group(seq_group)
  189. else:
  190. return RequestOutput.from_seq_group(seq_group)