1
0

outputs.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. import time
  2. from dataclasses import dataclass
  3. from typing import List, Optional
  4. from typing import Sequence as GenericSequence
  5. from typing import Union
  6. from aphrodite.common.sampling_params import RequestOutputKind
  7. from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
  8. SampleLogprobs, SequenceGroup,
  9. SequenceStatus)
  10. from aphrodite.lora.request import LoRARequest
  @dataclass
  class CompletionOutput:
      """A single generated completion belonging to a request.

      Args:
          index: The index of the output in the request.
          text: The generated output text.
          token_ids: The token IDs of the generated output text.
          cumulative_logprob: The cumulative log probability of the
              generated output text.
          logprobs: The log probabilities of the top probability words at
              each position if the logprobs are requested.
          finish_reason: The reason why the sequence is finished; None
              means the completion is not finished (see ``finished``).
          stop_reason: The stop string or token id that caused the
              completion to stop, None if the completion finished for some
              other reason including encountering the EOS token.
          lora_request: The LoRA request that was used to generate the
              output.
      """

      index: int
      text: str
      token_ids: GenericSequence[int]
      cumulative_logprob: Optional[float]
      logprobs: Optional[SampleLogprobs]
      finish_reason: Optional[str] = None
      stop_reason: Union[int, str, None] = None
      lora_request: Optional[LoRARequest] = None

      def finished(self) -> bool:
          """Return True once a finish reason has been recorded."""
          return not (self.finish_reason is None)

      def __repr__(self) -> str:
          """Render every field except ``lora_request``."""
          rendered_fields = (
              f"index={self.index}",
              f"text={self.text!r}",
              f"token_ids={self.token_ids}",
              f"cumulative_logprob={self.cumulative_logprob}",
              f"logprobs={self.logprobs}",
              f"finish_reason={self.finish_reason}",
              f"stop_reason={self.stop_reason}",
          )
          return "CompletionOutput(" + ", ".join(rendered_fields) + ")"
  @dataclass
  class EmbeddingOutput:
      """The output data of one embedding output of a request.

      Args:
          embedding: The embedding vector, which is a list of floats. The
              length of the vector depends on the model as listed in the
              embedding guide.
      """

      embedding: List[float]

      def __repr__(self) -> str:
          """Show the length of the vector rather than its contents."""
          vector_length = len(self.embedding)
          return f"EmbeddingOutput(embedding={vector_length})"
  class RequestOutput:
      """The output data of a completion request to the LLM.

      Args:
          request_id: The unique ID of the request.
          prompt: The prompt string of the request.
                  For encoder/decoder models, this is the
                  decoder input prompt.
          prompt_token_ids: The token IDs of the prompt.
                            For encoder/decoder models, this is the
                            decoder input prompt token ids.
          prompt_logprobs: The log probabilities to return per prompt token.
          outputs: The output sequences of the request.
          finished: Whether the whole request is finished.
          metrics: Metrics associated with the request.
          lora_request: The LoRA request that was used to generate the output.
          encoder_prompt: The encoder prompt string of the request;
                          None if decoder-only
          encoder_prompt_token_ids: The token IDs of the encoder prompt;
                                    None if decoder-only
      """

      def __init__(
          self,
          request_id: str,
          prompt: Optional[str],
          prompt_token_ids: Optional[List[int]],
          prompt_logprobs: Optional[PromptLogprobs],
          outputs: List[CompletionOutput],
          finished: bool,
          metrics: Optional[RequestMetrics] = None,
          lora_request: Optional[LoRARequest] = None,
          encoder_prompt: Optional[str] = None,
          encoder_prompt_token_ids: Optional[List[int]] = None,
      ) -> None:
          self.request_id = request_id
          self.prompt = prompt
          self.prompt_token_ids = prompt_token_ids
          self.prompt_logprobs = prompt_logprobs
          self.outputs = outputs
          self.finished = finished
          self.metrics = metrics
          self.lora_request = lora_request
          self.encoder_prompt = encoder_prompt
          self.encoder_prompt_token_ids = encoder_prompt_token_ids

      @classmethod
      def from_seq_group(cls,
                         seq_group: SequenceGroup) -> Optional["RequestOutput"]:
          """Build a RequestOutput from the current state of ``seq_group``.

          Besides building the result, this calls
          ``seq_group.set_finished_time`` with the current time when the
          group is finished and with None otherwise.

          Args:
              seq_group: The sequence group to report on.

          Returns:
              The request output, or None when the sampling params ask for
              ``RequestOutputKind.FINAL_ONLY`` and the group is not finished.

          Raises:
              ValueError: If ``seq_group.sampling_params`` is None.
          """
          sampling_params = seq_group.sampling_params
          if sampling_params is None:
              raise ValueError(
                  "Sampling parameters are missing for a CompletionRequest.")

          finished = seq_group.is_finished()
          if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                  not finished):
              return None

          seqs = seq_group.get_seqs()
          if len(seqs) == 1:
              top_n_seqs = seqs
          else:
              # Get the top-n sequences.
              n = sampling_params.n
              if sampling_params.use_beam_search:
                  sorting_key = lambda seq: seq.get_beam_search_score(
                      sampling_params.length_penalty)
              else:
                  sorting_key = lambda seq: seq.get_cumulative_logprob()
              sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
              top_n_seqs = sorted_seqs[:n]

          # Create the outputs.
          # NOTE: We need to omit logprobs here explicitly because the
          # sequence always has the logprobs of the sampled tokens even if
          # the logprobs are not requested.
          include_logprobs = sampling_params.logprobs is not None
          text_buffer_length = sampling_params.output_text_buffer_length
          delta = sampling_params.output_kind == RequestOutputKind.DELTA

          outputs = []
          include_prompt = True
          for seq in top_n_seqs:
              output_text = seq.get_output_text_to_return(
                  text_buffer_length, delta)
              output_token_ids = seq.get_output_token_ids_to_return(delta)
              output_logprobs = seq.output_logprobs if include_logprobs else None

              if delta:
                  # Slice logprobs delta if applicable.
                  if output_logprobs:
                      num_new_tokens = len(output_token_ids)
                      # ``x[-0:]`` is ``x[0:]`` (the whole list), so a delta
                      # with no new tokens must be handled explicitly or the
                      # full logprobs history would be returned again.
                      output_logprobs = (output_logprobs[-num_new_tokens:]
                                         if num_new_tokens else [])
                  # Don't include prompt if this is after the first output
                  # containing decode token ids
                  if include_prompt and seq.get_output_len() > len(
                          output_token_ids):
                      include_prompt = False

              outputs.append(
                  CompletionOutput(
                      # Index is the position in the full group, not in the
                      # top-n selection.
                      seqs.index(seq), output_text, output_token_ids,
                      seq.get_cumulative_logprob() if include_logprobs else None,
                      output_logprobs,
                      SequenceStatus.get_finished_reason(seq.status),
                      seq.stop_reason))

          # Every sequence in the sequence group should have the same prompt.
          if include_prompt:
              prompt = seq_group.prompt
              prompt_token_ids = seq_group.prompt_token_ids
              encoder_prompt = seq_group.encoder_prompt
              encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
              prompt_logprobs = seq_group.prompt_logprobs
          else:
              prompt = None
              prompt_token_ids = None
              encoder_prompt = None
              encoder_prompt_token_ids = None
              prompt_logprobs = None

          finished_time = time.time() if finished else None
          seq_group.set_finished_time(finished_time)
          return cls(
              seq_group.request_id,
              prompt,
              prompt_token_ids,
              prompt_logprobs,
              outputs,
              finished,
              seq_group.metrics,
              lora_request=seq_group.lora_request,
              encoder_prompt=encoder_prompt,
              encoder_prompt_token_ids=encoder_prompt_token_ids,
          )

      def __repr__(self) -> str:
          """Render all public fields of the request output."""
          return (f"RequestOutput(request_id={self.request_id}, "
                  f"prompt={self.prompt!r}, "
                  f"prompt_token_ids={self.prompt_token_ids}, "
                  f"encoder_prompt={self.encoder_prompt!r}, "
                  f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                  f"prompt_logprobs={self.prompt_logprobs}, "
                  f"outputs={self.outputs}, "
                  f"finished={self.finished}, "
                  f"metrics={self.metrics}, "
                  f"lora_request={self.lora_request})")
  class EmbeddingRequestOutput:
      """The output data of an embedding request to the LLM.

      Args:
          request_id (str): A unique identifier for the embedding request.
          outputs (EmbeddingOutput): The embedding results for the given
              input.
          prompt_token_ids (List[int]): A list of token IDs used in the
              prompt.
          finished (bool): A flag indicating whether the embedding is
              completed.
      """

      def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
                   prompt_token_ids: List[int], finished: bool):
          self.request_id = request_id
          self.prompt_token_ids = prompt_token_ids
          self.finished = finished
          self.outputs = outputs

      @classmethod
      def from_seq_group(cls,
                         seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
          """Build an EmbeddingRequestOutput from ``seq_group``.

          Raises:
              ValueError: If ``seq_group.embeddings`` is None.
          """
          vector = seq_group.embeddings
          if vector is None:
              raise ValueError(
                  "Embeddings are missing in seq_group for EmbeddingRequest.")

          wrapped = EmbeddingOutput(vector)
          token_ids = seq_group.prompt_token_ids
          is_done = seq_group.is_finished()
          return cls(seq_group.request_id, wrapped, token_ids, is_done)

      def __repr__(self):
          """Return a string representation of this instance.

          The string contains the quoted request_id, the repr of
          ``outputs``, the prompt token ids and the finished flag.

          Returns:
              str: A string representation of the EmbeddingRequestOutput
              instance.
          """
          rendered_fields = (
              f"request_id='{self.request_id}'",
              f"outputs={self.outputs!r}",
              f"prompt_token_ids={self.prompt_token_ids}",
              f"finished={self.finished}",
          )
          return "EmbeddingRequestOutput(" + ", ".join(rendered_fields) + ")"
  class RequestOutputFactory:
      """Picks the request-output class that matches a sequence group."""

      @staticmethod
      def create(seq_group):
          """Create the output object for ``seq_group``.

          A group whose ``embeddings`` attribute exists and is not None is
          turned into an EmbeddingRequestOutput; any other group is turned
          into a RequestOutput.
          """
          if getattr(seq_group, 'embeddings', None) is None:
              return RequestOutput.from_seq_group(seq_group)
          return EmbeddingRequestOutput.from_seq_group(seq_group)