# outputs.py
  1. from typing import List, Optional, Union
  2. from dataclasses import dataclass
  3. import time
  4. from aphrodite.common.sequence import (
  5. PromptLogprobs,
  6. SampleLogprobs,
  7. SequenceGroup,
  8. SequenceStatus,
  9. RequestMetrics,
  10. )
  11. from aphrodite.lora.request import LoRARequest
  12. @dataclass
  13. class CompletionOutput:
  14. """The output data of one completion output of a request.
  15. Args:
  16. index: The index of the output in the request.
  17. text: The generated output text.
  18. token_ids: The token IDs of the generated output text.
  19. cumulative_logprob: The cumulative log probability of the generated
  20. output text.
  21. logprobs: The log probabilities of the top probability words at each
  22. position if the logprobs are requested.
  23. finish_reason: The reason why the sequence is finished.
  24. stop_reason: The stop string or token id that caused the completion
  25. to stop, None if the completion finished for some other reason
  26. including encountering the EOS token.
  27. lora_request: The LoRA request that was used to generate the output.
  28. """
  29. index: int
  30. text: str
  31. token_ids: List[int]
  32. cumulative_logprob: float
  33. logprobs: Optional[SampleLogprobs]
  34. finish_reason: Optional[str] = None
  35. stop_reason: Union[int, str, None] = None
  36. lora_request: Optional[LoRARequest] = None
  37. def finished(self) -> bool:
  38. return self.finish_reason is not None
  39. def __repr__(self) -> str:
  40. return (f"CompletionOutput(index={self.index}, "
  41. f"text={self.text!r}, "
  42. f"token_ids={self.token_ids}, "
  43. f"cumulative_logprob={self.cumulative_logprob}, "
  44. f"logprobs={self.logprobs}, "
  45. f"finish_reason={self.finish_reason}, "
  46. f"stop_reason={self.stop_reason})")
  47. @dataclass
  48. class EmbeddingOutput:
  49. """The output data of one completion output of a request.
  50. Args:
  51. embedding: The embedding vector, which is a list of floats. The
  52. length of vector depends on the model as listed in the embedding guide.
  53. """
  54. embedding: List[float]
  55. def __repr__(self) -> str:
  56. return (f"EmbeddingOutput("
  57. f"embedding={len(self.embedding)})")
  58. class RequestOutput:
  59. """The output data of a completion request to the LLM.
  60. Args:
  61. request_id: The unique ID of the request.
  62. prompt: The prompt string of the request.
  63. prompt_token_ids: The token IDs of the prompt.
  64. prompt_logprobs: The log probabilities to return per prompt token.
  65. outputs: The output sequences of the request.
  66. finished: Whether the whole request is finished.
  67. metrics: Metrics associated with the request.
  68. lora_request: The LoRA request that was used to generate the output.
  69. """
  70. def __init__(
  71. self,
  72. request_id: str,
  73. prompt: Optional[str],
  74. prompt_token_ids: List[int],
  75. prompt_logprobs: Optional[PromptLogprobs],
  76. outputs: List[CompletionOutput],
  77. finished: bool,
  78. metrics: Optional[RequestMetrics] = None,
  79. lora_request: Optional[LoRARequest] = None,
  80. ) -> None:
  81. self.request_id = request_id
  82. self.prompt = prompt
  83. self.prompt_token_ids = prompt_token_ids
  84. self.prompt_logprobs = prompt_logprobs
  85. self.outputs = outputs
  86. self.finished = finished
  87. self.metrics = metrics
  88. self.lora_request = lora_request
  89. @classmethod
  90. def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
  91. if seq_group.sampling_params is None:
  92. raise ValueError(
  93. "Sampling parameters are missing for a CompletionRequest.")
  94. # Get the top-n sequences.
  95. n = seq_group.sampling_params.n
  96. seqs = seq_group.get_seqs()
  97. if n == 1:
  98. top_n_seqs = seqs
  99. else:
  100. if seq_group.sampling_params.use_beam_search:
  101. sorting_key = lambda seq: seq.get_beam_search_score(
  102. seq_group.sampling_params.length_penalty)
  103. else:
  104. sorting_key = lambda seq: seq.get_cumulative_logprob()
  105. sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
  106. top_n_seqs = sorted_seqs[:n]
  107. # Create the outputs.
  108. # NOTE: We need omit logprobs here explicitly because the sequence
  109. # always has the logprobs of the sampled tokens even if the
  110. # logprobs are not requested.
  111. include_logprobs = seq_group.sampling_params.logprobs is not None
  112. text_buffer_length = seq_group.sampling_params.output_text_buffer_length
  113. outputs = [
  114. CompletionOutput(seqs.index(seq),
  115. seq.get_output_text_to_return(text_buffer_length),
  116. seq.get_output_token_ids(),
  117. seq.get_cumulative_logprob(),
  118. seq.output_logprobs if include_logprobs else None,
  119. SequenceStatus.get_finished_reason(seq.status),
  120. seq.stop_reason) for seq in top_n_seqs
  121. ]
  122. # Every sequence in the sequence group should have the same prompt.
  123. prompt = seq_group.prompt
  124. prompt_token_ids = seq_group.prompt_token_ids
  125. prompt_logprobs = seq_group.prompt_logprobs
  126. finished = seq_group.is_finished()
  127. finished_time = time.time() if finished else None
  128. seq_group.set_finished_time(finished_time)
  129. return cls(
  130. seq_group.request_id,
  131. prompt,
  132. prompt_token_ids,
  133. prompt_logprobs,
  134. outputs,
  135. finished,
  136. seq_group.metrics,
  137. lora_request=seq_group.lora_request,
  138. )
  139. def __repr__(self) -> str:
  140. return (f"RequestOutput(request_id={self.request_id}, "
  141. f"prompt={self.prompt!r}, "
  142. f"prompt_token_ids={self.prompt_token_ids}, "
  143. f"prompt_logprobs={self.prompt_logprobs}, "
  144. f"outputs={self.outputs}, "
  145. f"finished={self.finished}, "
  146. f"metrics={self.metrics}, "
  147. f"lora_request={self.lora_request})")
  148. class EmbeddingRequestOutput:
  149. """
  150. The output data of an embedding request to the LLM.
  151. Args:
  152. request_id (str): A unique identifier for the embedding request.
  153. outputs (EmbeddingOutput): The embedding results for the given input.
  154. prompt_token_ids (List[int]): A list of token IDs used in the prompt.
  155. finished (bool): A flag indicating whether the embedding is completed.
  156. """
  157. def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
  158. prompt_token_ids: List[int], finished: bool):
  159. self.request_id = request_id
  160. self.prompt_token_ids = prompt_token_ids
  161. self.finished = finished
  162. self.outputs = outputs
  163. @classmethod
  164. def from_seq_group(cls,
  165. seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
  166. if seq_group.embeddings is None:
  167. raise ValueError(
  168. "Embeddings are missing in seq_group for EmbeddingRequest.")
  169. output = EmbeddingOutput(seq_group.embeddings)
  170. prompt_token_ids = seq_group.prompt_token_ids
  171. finished = seq_group.is_finished()
  172. return cls(seq_group.request_id, output, prompt_token_ids, finished)
  173. def __repr__(self):
  174. """
  175. Returns a string representation of an EmbeddingRequestOutput instance.
  176. The representation includes the request_id and the number of outputs,
  177. providing a quick overview of the embedding request's results.
  178. Returns:
  179. str: A string representation of the EmbeddingRequestOutput instance.
  180. """
  181. return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
  182. f"outputs={repr(self.outputs)}, "
  183. f"prompt_token_ids={self.prompt_token_ids}, "
  184. f"finished={self.finished})")
  185. class RequestOutputFactory:
  186. @staticmethod
  187. def create(seq_group):
  188. # Determine the type based on a condition, for example:
  189. if hasattr(seq_group,
  190. 'embeddings') and seq_group.embeddings is not None:
  191. return EmbeddingRequestOutput.from_seq_group(seq_group)
  192. else:
  193. return RequestOutput.from_seq_group(seq_group)