# outputs.py

import time
from dataclasses import dataclass
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from aphrodite.common.sampling_params import RequestOutputKind
from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
                                       SampleLogprobs, SequenceGroup,
                                       SequenceStatus)
from aphrodite.lora.request import LoRARequest


@dataclass
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        return self.finish_reason is not None

    def __repr__(self) -> str:
        return (f"CompletionOutput(index={self.index}, "
                f"text={self.text!r}, "
                f"token_ids={self.token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"logprobs={self.logprobs}, "
                f"finish_reason={self.finish_reason}, "
                f"stop_reason={self.stop_reason})")


@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats. The
            length of the vector depends on the model as listed in the
            embedding guide.
    """

    embedding: List[float]

    def __repr__(self) -> str:
        return (f"EmbeddingOutput("
                f"embedding={len(self.embedding)})")


class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request;
                        None if decoder-only
        encoder_prompt_token_ids: The token IDs of the encoder prompt;
                                  None if decoder-only
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup,
                       use_cache: bool) -> Optional["RequestOutput"]:
        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        finished = seq_group.is_finished()
        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                not finished):
            return None

        # Init cache (if needed)
        if use_cache and seq_group.cached_request_output is None:
            seq_group.cached_request_output = RequestOutput(  # type: ignore
                request_id="",
                prompt=None,
                prompt_token_ids=[],
                prompt_logprobs=None,
                outputs=[],
                finished=False)

        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
            n = sampling_params.n
            if sampling_params.use_beam_search:
                sorting_key = lambda seq: seq.get_beam_search_score(
                    sampling_params.length_penalty)
            else:
                sorting_key = lambda seq: seq.get_cumulative_logprob()
            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]

        # Create the outputs.
        # NOTE: We need to omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        delta = sampling_params.output_kind == RequestOutputKind.DELTA

        outputs = []
        include_prompt = True
        for i, seq in enumerate(top_n_seqs):
            output_text = seq.get_output_text_to_return(
                text_buffer_length, delta)
            output_token_ids = seq.get_output_token_ids_to_return(delta)
            num_output_tokens = 1 if isinstance(output_token_ids,
                                                int) else len(output_token_ids)
            output_logprobs = seq.output_logprobs if include_logprobs else None

            if delta:
                # Slice logprobs delta if applicable
                if output_logprobs:
                    output_logprobs = output_logprobs[-num_output_tokens:]
                # Don't include prompt if this is after the first output
                # containing decode token ids
                if include_prompt and seq.get_output_len() > num_output_tokens:
                    include_prompt = False

            if use_cache:
                # Get cached output object
                cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
                if i >= len(cached_outputs):
                    cached_outputs.append(
                        CompletionOutput(index=i,
                                         text="",
                                         token_ids=[],
                                         cumulative_logprob=None,
                                         logprobs=None,
                                         finish_reason=None,
                                         stop_reason=None))

                output = cached_outputs[i]

                # Init cached output object
                assert output.index == i
                output.text = output_text

                if isinstance(output_token_ids, int):
                    output.token_ids.clear()
                    output.token_ids.append(output_token_ids)
                else:
                    output.token_ids = output_token_ids

                output.cumulative_logprob = seq.get_cumulative_logprob() \
                    if include_logprobs else None
                output.logprobs = output_logprobs
                output.finish_reason = SequenceStatus.get_finished_reason(
                    seq.status)
                output.stop_reason = seq.stop_reason

            else:
                output = CompletionOutput(
                    seqs.index(seq), output_text, [output_token_ids]
                    if isinstance(output_token_ids, int) else output_token_ids,
                    seq.get_cumulative_logprob() if include_logprobs else None,
                    output_logprobs,
                    SequenceStatus.get_finished_reason(seq.status),
                    seq.stop_reason)

            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        if include_prompt:
            prompt = seq_group.prompt
            prompt_token_ids = seq_group.prompt_token_ids
            encoder_prompt = seq_group.encoder_prompt
            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
            prompt_logprobs = seq_group.prompt_logprobs
        else:
            prompt = None
            prompt_token_ids = None
            encoder_prompt = None
            encoder_prompt_token_ids = None
            prompt_logprobs = None
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)

        init_args = (seq_group.request_id, prompt, prompt_token_ids,
                     prompt_logprobs, outputs, finished, seq_group.metrics,
                     seq_group.lora_request, encoder_prompt,
                     encoder_prompt_token_ids)

        if use_cache:
            request_output = seq_group.cached_request_output
            request_output.__init__(*init_args)  # type: ignore
        else:
            request_output = cls(*init_args)

        return request_output

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")


class EmbeddingRequestOutput:
    """
    The output data of an embedding request to the LLM.

    Args:
        request_id (str): A unique identifier for the embedding request.
        outputs (EmbeddingOutput): The embedding results for the given input.
        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the embedding is completed.
    """

    def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
                 prompt_token_ids: List[int], finished: bool):
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @classmethod
    def from_seq_group(cls,
                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
        if seq_group.embeddings is None:
            raise ValueError(
                "Embeddings are missing in seq_group for EmbeddingRequest.")
        output = EmbeddingOutput(seq_group.embeddings)
        prompt_token_ids = seq_group.prompt_token_ids
        finished = seq_group.is_finished()

        return cls(seq_group.request_id, output, prompt_token_ids, finished)

    def __repr__(self):
        """
        Returns a string representation of an EmbeddingRequestOutput instance.

        The representation includes the request_id and the number of outputs,
        providing a quick overview of the embedding request's results.

        Returns:
            str: A string representation of the EmbeddingRequestOutput instance.
        """
        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")


class RequestOutputFactory:

    @staticmethod
    def create(seq_group: SequenceGroup, use_cache: bool = False):
        # Dispatch on the sequence group itself: embedding groups produce an
        # EmbeddingRequestOutput, completion groups produce a RequestOutput.
        if hasattr(seq_group,
                   'embeddings') and seq_group.embeddings is not None:
            return EmbeddingRequestOutput.from_seq_group(seq_group)
        else:
            return RequestOutput.from_seq_group(seq_group, use_cache)
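

# Example sketch: the factory lets a caller stay agnostic about request type,
# since dispatch is driven by the sequence group itself. `finished_groups` and
# the helper name are assumptions used only for illustration.
def _example_factory_dispatch(
    finished_groups: List[SequenceGroup],
    use_cache: bool = False,
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
    results: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
    for seq_group in finished_groups:
        # Embedding groups become EmbeddingRequestOutput; everything else
        # becomes RequestOutput, which may be None until the group is ready
        # to emit (see RequestOutput.from_seq_group).
        output = RequestOutputFactory.create(seq_group, use_cache)
        if output is not None:
            results.append(output)
    return results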