outputs.py

import time
from dataclasses import dataclass
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
                                       SampleLogprobs, SequenceGroup,
                                       SequenceStatus)
from aphrodite.lora.request import LoRARequest


@dataclass
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position, if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop. None if the completion finished for some other reason,
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        return self.finish_reason is not None

    def __repr__(self) -> str:
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")
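

# Usage sketch: constructing a CompletionOutput by hand and checking whether
# it has finished. A minimal sketch; all field values below are hypothetical.
#
#     output = CompletionOutput(
#         index=0,
#         text="Hello, world!",
#         token_ids=[15496, 11, 995, 0],
#         cumulative_logprob=-1.23,
#         logprobs=None,
#         finish_reason="stop",
#     )
#     assert output.finished()  # finish_reason is set, so the output is done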


@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats. The
            length of the vector depends on the model as listed in the
            embedding guide.
    """

    embedding: List[float]

    def __repr__(self) -> str:
        return (f"EmbeddingOutput("
                f"embedding={len(self.embedding)})")
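

# Usage sketch: the __repr__ above deliberately reports only the length of
# the vector, since embeddings can be thousands of floats long. The
# 4-dimensional vector below is hypothetical.
#
#     emb = EmbeddingOutput(embedding=[0.12, -0.34, 0.56, -0.78])
#     print(repr(emb))  # EmbeddingOutput(embedding=4)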


class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
            For encoder/decoder models, this is the
            decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
            For encoder/decoder models, this is the
            decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request;
            None if decoder-only.
        encoder_prompt_token_ids: The token IDs of the encoder prompt;
            None if decoder-only.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        if seq_group.sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        # Get the top-n sequences.
        n = seq_group.sampling_params.n
        seqs = seq_group.get_seqs()
        if n == 1:
            top_n_seqs = seqs
        else:
            if seq_group.sampling_params.use_beam_search:
                sorting_key = lambda seq: seq.get_beam_search_score(
                    seq_group.sampling_params.length_penalty)
            else:
                sorting_key = lambda seq: seq.get_cumulative_logprob()
            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = seq_group.sampling_params.logprobs is not None
        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
        outputs = [
            CompletionOutput(
                seqs.index(seq),
                seq.get_output_text_to_return(text_buffer_length),
                seq.data._output_token_ids,
                seq.get_cumulative_logprob() if include_logprobs else None,
                seq.output_logprobs if include_logprobs else None,
                SequenceStatus.get_finished_reason(seq.status),
                seq.stop_reason) for seq in top_n_seqs
        ]

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        encoder_prompt = seq_group.encoder_prompt
        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)
        return cls(
            seq_group.request_id,
            prompt,
            prompt_token_ids,
            prompt_logprobs,
            outputs,
            finished,
            seq_group.metrics,
            lora_request=seq_group.lora_request,
            encoder_prompt=encoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_token_ids,
        )

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
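

# Usage sketch: how RequestOutput instances typically reach user code. A
# minimal sketch assuming the top-level `LLM` and `SamplingParams` entry
# points of aphrodite-engine; the model id and prompt are placeholders.
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="some/model")  # hypothetical model id
#     for request_output in llm.generate(["Hello"],
#                                        SamplingParams(max_tokens=16)):
#         for completion in request_output.outputs:
#             print(completion.index, completion.finish_reason,
#                   completion.text)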


class EmbeddingRequestOutput:
    """
    The output data of an embedding request to the LLM.

    Args:
        request_id (str): A unique identifier for the embedding request.
        outputs (EmbeddingOutput): The embedding results for the given input.
        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the embedding is completed.
    """

    def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
                 prompt_token_ids: List[int], finished: bool):
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @classmethod
    def from_seq_group(
            cls, seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
        if seq_group.embeddings is None:
            raise ValueError(
                "Embeddings are missing in seq_group for EmbeddingRequest.")
        output = EmbeddingOutput(seq_group.embeddings)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()
        return cls(seq_group.request_id, output, prompt_token_ids, finished)

    def __repr__(self):
        """
        Returns a string representation of an EmbeddingRequestOutput instance.

        The representation includes the request_id, the outputs (shown via
        their embedding length), the prompt token IDs, and the finished flag,
        providing a quick overview of the embedding request's results.

        Returns:
            str: A string representation of the EmbeddingRequestOutput
            instance.
        """
        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")
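

# Usage sketch: building an EmbeddingRequestOutput directly. The request id
# and values below are hypothetical; in practice instances are produced by
# from_seq_group() above.
#
#     emb_out = EmbeddingRequestOutput(
#         request_id="embed-0",
#         outputs=EmbeddingOutput(embedding=[0.1, 0.2, 0.3]),
#         prompt_token_ids=[101, 2023, 102],
#         finished=True,
#     )
#     print(emb_out)  # repr shows the embedding length, not the full vector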


class RequestOutputFactory:

    @staticmethod
    def create(seq_group):
        # Dispatch on the sequence group's contents: embedding requests
        # carry a non-None `embeddings` attribute, completion requests
        # do not.
        if hasattr(seq_group,
                   'embeddings') and seq_group.embeddings is not None:
            return EmbeddingRequestOutput.from_seq_group(seq_group)
        else:
            return RequestOutput.from_seq_group(seq_group)
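

# Smoke-test sketch of the factory dispatch. The stub below is hypothetical:
# it stands in for a real SequenceGroup and carries only the attributes that
# EmbeddingRequestOutput.from_seq_group() reads.
if __name__ == "__main__":
    from types import SimpleNamespace

    stub = SimpleNamespace(
        request_id="embed-0",
        embeddings=[0.1, 0.2, 0.3],
        prompt_token_ids=[101, 102],
        is_finished=lambda: True,
    )
    # The stub has a non-None `embeddings` attribute, so the factory takes
    # the embedding path rather than the completion path.
    result = RequestOutputFactory.create(stub)
    assert isinstance(result, EmbeddingRequestOutput)
    print(result)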