outputs.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import time
  2. from dataclasses import dataclass
  3. from typing import List, Optional, Tuple, Union
  4. from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
  5. SampleLogprobs, SequenceGroup,
  6. SequenceStatus)
  7. from aphrodite.lora.request import LoRARequest
  8. @dataclass
  9. class CompletionOutput:
  10. """The output data of one completion output of a request.
  11. Args:
  12. index: The index of the output in the request.
  13. text: The generated output text.
  14. token_ids: The token IDs of the generated output text.
  15. cumulative_logprob: The cumulative log probability of the generated
  16. output text.
  17. logprobs: The log probabilities of the top probability words at each
  18. position if the logprobs are requested.
  19. finish_reason: The reason why the sequence is finished.
  20. stop_reason: The stop string or token id that caused the completion
  21. to stop, None if the completion finished for some other reason
  22. including encountering the EOS token.
  23. lora_request: The LoRA request that was used to generate the output.
  24. """
  25. index: int
  26. text: str
  27. token_ids: Tuple[int, ...]
  28. cumulative_logprob: Optional[float]
  29. logprobs: Optional[SampleLogprobs]
  30. finish_reason: Optional[str] = None
  31. stop_reason: Union[int, str, None] = None
  32. lora_request: Optional[LoRARequest] = None
  33. def finished(self) -> bool:
  34. return self.finish_reason is not None
  35. def __repr__(self) -> str:
  36. return (f"CompletionOutput(index={self.index}, "
  37. f"text={self.text!r}, "
  38. f"token_ids={self.token_ids}, "
  39. f"cumulative_logprob={self.cumulative_logprob}, "
  40. f"logprobs={self.logprobs}, "
  41. f"finish_reason={self.finish_reason}, "
  42. f"stop_reason={self.stop_reason})")
  43. @dataclass
  44. class EmbeddingOutput:
  45. """The output data of one completion output of a request.
  46. Args:
  47. embedding: The embedding vector, which is a list of floats. The
  48. length of vector depends on the model as listed in the embedding guide.
  49. """
  50. embedding: List[float]
  51. def __repr__(self) -> str:
  52. return (f"EmbeddingOutput("
  53. f"embedding={len(self.embedding)})")
  54. class RequestOutput:
  55. """The output data of a completion request to the LLM.
  56. Args:
  57. request_id: The unique ID of the request.
  58. prompt: The prompt string of the request.
  59. For encoder/decoder models, this is the
  60. decoder input prompt.
  61. prompt_token_ids: The token IDs of the prompt.
  62. For encoder/decoder models, this is the
  63. decoder input prompt token ids.
  64. prompt_logprobs: The log probabilities to return per prompt token.
  65. outputs: The output sequences of the request.
  66. finished: Whether the whole request is finished.
  67. metrics: Metrics associated with the request.
  68. lora_request: The LoRA request that was used to generate the output.
  69. encoder_prompt: The encoder prompt string of the request;
  70. None if decoder-only
  71. encoder_prompt_token_ids: The token IDs of the encoder prompt;
  72. None if decoder-only
  73. """
  74. def __init__(
  75. self,
  76. request_id: str,
  77. prompt: Optional[str],
  78. prompt_token_ids: List[int],
  79. prompt_logprobs: Optional[PromptLogprobs],
  80. outputs: List[CompletionOutput],
  81. finished: bool,
  82. metrics: Optional[RequestMetrics] = None,
  83. lora_request: Optional[LoRARequest] = None,
  84. encoder_prompt: Optional[str] = None,
  85. encoder_prompt_token_ids: Optional[List[int]] = None,
  86. ) -> None:
  87. self.request_id = request_id
  88. self.prompt = prompt
  89. self.prompt_token_ids = prompt_token_ids
  90. self.prompt_logprobs = prompt_logprobs
  91. self.outputs = outputs
  92. self.finished = finished
  93. self.metrics = metrics
  94. self.lora_request = lora_request
  95. self.encoder_prompt = encoder_prompt
  96. self.encoder_prompt_token_ids = encoder_prompt_token_ids
  97. @classmethod
  98. def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
  99. if seq_group.sampling_params is None:
  100. raise ValueError(
  101. "Sampling parameters are missing for a CompletionRequest.")
  102. # Get the top-n sequences.
  103. n = seq_group.sampling_params.n
  104. seqs = seq_group.get_seqs()
  105. if n == 1:
  106. top_n_seqs = seqs
  107. else:
  108. if seq_group.sampling_params.use_beam_search:
  109. sorting_key = lambda seq: seq.get_beam_search_score(
  110. seq_group.sampling_params.length_penalty)
  111. else:
  112. sorting_key = lambda seq: seq.get_cumulative_logprob()
  113. sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
  114. top_n_seqs = sorted_seqs[:n]
  115. # Create the outputs.
  116. # NOTE: We need omit logprobs here explicitly because the sequence
  117. # always has the logprobs of the sampled tokens even if the
  118. # logprobs are not requested.
  119. include_logprobs = seq_group.sampling_params.logprobs is not None
  120. text_buffer_length = seq_group.sampling_params.output_text_buffer_length
  121. outputs = [
  122. CompletionOutput(
  123. seqs.index(seq),
  124. seq.get_output_text_to_return(text_buffer_length),
  125. seq.get_output_token_ids(),
  126. seq.get_cumulative_logprob() if include_logprobs else None,
  127. seq.output_logprobs if include_logprobs else None,
  128. SequenceStatus.get_finished_reason(seq.status),
  129. seq.stop_reason) for seq in top_n_seqs
  130. ]
  131. # Every sequence in the sequence group should have the same prompt.
  132. prompt = seq_group.prompt
  133. prompt_token_ids = seq_group.prompt_token_ids
  134. encoder_prompt = seq_group.encoder_prompt
  135. encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
  136. prompt_logprobs = seq_group.prompt_logprobs
  137. finished = seq_group.is_finished()
  138. finished_time = time.time() if finished else None
  139. seq_group.set_finished_time(finished_time)
  140. return cls(
  141. seq_group.request_id,
  142. prompt,
  143. prompt_token_ids,
  144. prompt_logprobs,
  145. outputs,
  146. finished,
  147. seq_group.metrics,
  148. lora_request=seq_group.lora_request,
  149. encoder_prompt=encoder_prompt,
  150. encoder_prompt_token_ids=encoder_prompt_token_ids,
  151. )
  152. def __repr__(self) -> str:
  153. return (f"RequestOutput(request_id={self.request_id}, "
  154. f"prompt={self.prompt!r}, "
  155. f"prompt_token_ids={self.prompt_token_ids}, "
  156. f"encoder_prompt={self.encoder_prompt!r}, "
  157. f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
  158. f"prompt_logprobs={self.prompt_logprobs}, "
  159. f"outputs={self.outputs}, "
  160. f"finished={self.finished}, "
  161. f"metrics={self.metrics}, "
  162. f"lora_request={self.lora_request})")
  163. class EmbeddingRequestOutput:
  164. """
  165. The output data of an embedding request to the LLM.
  166. Args:
  167. request_id (str): A unique identifier for the embedding request.
  168. outputs (EmbeddingOutput): The embedding results for the given input.
  169. prompt_token_ids (List[int]): A list of token IDs used in the prompt.
  170. finished (bool): A flag indicating whether the embedding is completed.
  171. """
  172. def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
  173. prompt_token_ids: List[int], finished: bool):
  174. self.request_id = request_id
  175. self.prompt_token_ids = prompt_token_ids
  176. self.finished = finished
  177. self.outputs = outputs
  178. @classmethod
  179. def from_seq_group(cls,
  180. seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
  181. if seq_group.embeddings is None:
  182. raise ValueError(
  183. "Embeddings are missing in seq_group for EmbeddingRequest.")
  184. output = EmbeddingOutput(seq_group.embeddings)
  185. prompt_token_ids = seq_group.prompt_token_ids
  186. finished = seq_group.is_finished()
  187. return cls(seq_group.request_id, output, prompt_token_ids, finished)
  188. def __repr__(self):
  189. """
  190. Returns a string representation of an EmbeddingRequestOutput instance.
  191. The representation includes the request_id and the number of outputs,
  192. providing a quick overview of the embedding request's results.
  193. Returns:
  194. str: A string representation of the EmbeddingRequestOutput instance.
  195. """
  196. return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
  197. f"outputs={repr(self.outputs)}, "
  198. f"prompt_token_ids={self.prompt_token_ids}, "
  199. f"finished={self.finished})")
  200. class RequestOutputFactory:
  201. @staticmethod
  202. def create(seq_group):
  203. # Determine the type based on a condition, for example:
  204. if hasattr(seq_group,
  205. 'embeddings') and seq_group.embeddings is not None:
  206. return EmbeddingRequestOutput.from_seq_group(seq_group)
  207. else:
  208. return RequestOutput.from_seq_group(seq_group)