1
0

outputs.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from typing import List, Optional, Union
  2. import time
  3. from aphrodite.common.sequence import (
  4. PromptLogprobs,
  5. SampleLogprobs,
  6. SequenceGroup,
  7. SequenceStatus,
  8. RequestMetrics,
  9. )
  10. from aphrodite.lora.request import LoRARequest
  class CompletionOutput:
      """The output data of one completion output of a request.

      Args:
          index: The index of the output in the request.
          text: The generated output text.
          token_ids: The token IDs of the generated output text.
          cumulative_logprob: The cumulative log probability of the generated
              output text.
          logprobs: The log probabilities of the top probability words at each
              position if the logprobs are requested.
          finish_reason: The reason why the sequence is finished.
          stop_reason: The stop string or token id that caused the completion
              to stop, None if the completion finished for some other reason
              including encountering the EOS token.
          lora_request: The LoRA request that was used to generate the output.
      """

      def __init__(
          self,
          index: int,
          text: str,
          token_ids: List[int],
          cumulative_logprob: float,
          # Project types are quoted (forward references) so the annotation
          # is not evaluated at definition time.
          logprobs: "Optional[SampleLogprobs]",
          finish_reason: Optional[str] = None,
          stop_reason: Union[int, str, None] = None,
          lora_request: "Optional[LoRARequest]" = None,
      ) -> None:
          self.index = index
          self.text = text
          self.token_ids = token_ids
          self.cumulative_logprob = cumulative_logprob
          self.logprobs = logprobs
          self.finish_reason = finish_reason
          self.stop_reason = stop_reason
          self.lora_request = lora_request

      def finished(self) -> bool:
          """Return True once a finish reason has been recorded."""
          return self.finish_reason is not None

      def __repr__(self) -> str:
          # Include every stored field (lora_request too) so the repr agrees
          # with RequestOutput's, which already prints lora_request.
          return (f"CompletionOutput(index={self.index}, "
                  f"text={self.text!r}, "
                  f"token_ids={self.token_ids}, "
                  f"cumulative_logprob={self.cumulative_logprob}, "
                  f"logprobs={self.logprobs}, "
                  f"finish_reason={self.finish_reason}, "
                  f"stop_reason={self.stop_reason}, "
                  f"lora_request={self.lora_request})")
  class RequestOutput:
      """The output data of a request to the LLM.

      Args:
          request_id: The unique ID of the request.
          prompt: The prompt string of the request.
          prompt_token_ids: The token IDs of the prompt.
          prompt_logprobs: The log probabilities to return per prompt token.
          outputs: The output sequences of the request.
          finished: Whether the whole request is finished.
          metrics: Metrics associated with the request.
          lora_request: The LoRA request that was used to generate the output.
      """

      def __init__(
          self,
          request_id: str,
          prompt: str,
          prompt_token_ids: List[int],
          # Project types are quoted (forward references) so the annotations
          # are not evaluated at definition time.
          prompt_logprobs: "Optional[PromptLogprobs]",
          outputs: "List[CompletionOutput]",
          finished: bool,
          metrics: "Optional[RequestMetrics]" = None,
          lora_request: "Optional[LoRARequest]" = None,
      ) -> None:
          self.request_id = request_id
          self.prompt = prompt
          self.prompt_token_ids = prompt_token_ids
          self.prompt_logprobs = prompt_logprobs
          self.outputs = outputs
          self.finished = finished
          self.metrics = metrics
          self.lora_request = lora_request

      @classmethod
      def from_seq_group(cls, seq_group: "SequenceGroup") -> "RequestOutput":
          """Build a RequestOutput snapshot from a sequence group.

          The group's sequences are ranked (by beam-search score when beam
          search is enabled, otherwise by cumulative logprob) and the best
          ``sampling_params.n`` are converted into CompletionOutputs. Each
          output keeps the index of its sequence in ``seq_group.get_seqs()``.

          Side effect: calls ``seq_group.set_finished_time`` with the current
          wall-clock time if the group is finished, or None otherwise.
          """
          sampling_params = seq_group.sampling_params
          seqs = seq_group.get_seqs()

          if len(seqs) == 1:
              # A single candidate needs no ranking. The fast path keys on the
              # actual sequence count rather than on ``n == 1``: the group can
              # hold more sequences than ``n`` (hence the ``[:n]`` slice
              # below), and that is possible even when ``n == 1``.
              top_n_seqs = seqs
          else:
              n = sampling_params.n
              if sampling_params.use_beam_search:
                  length_penalty = sampling_params.length_penalty

                  def sorting_key(seq):
                      return seq.get_beam_search_score(length_penalty)
              else:

                  def sorting_key(seq):
                      return seq.get_cumulative_logprob()

              top_n_seqs = sorted(seqs, key=sorting_key, reverse=True)[:n]

          # NOTE: logprobs are omitted explicitly because a sequence always
          # carries the logprobs of its sampled tokens, even when the request
          # did not ask for them.
          include_logprobs = sampling_params.logprobs is not None
          text_buffer_length = sampling_params.output_text_buffer_length
          outputs = [
              CompletionOutput(
                  seqs.index(seq),
                  seq.get_output_text_to_return(text_buffer_length),
                  seq.get_output_token_ids(),
                  seq.get_cumulative_logprob(),
                  seq.output_logprobs if include_logprobs else None,
                  SequenceStatus.get_finished_reason(seq.status),
                  seq.stop_reason,
                  # Populate the per-output LoRA field, which was previously
                  # always left as None.
                  lora_request=seq_group.lora_request,
              ) for seq in top_n_seqs
          ]

          # Every sequence in the sequence group should have the same prompt,
          # so the prompt data is read from the group itself.
          prompt = seq_group.prompt
          prompt_token_ids = seq_group.prompt_token_ids
          prompt_logprobs = seq_group.prompt_logprobs

          finished = seq_group.is_finished()
          # Current wall-clock time once finished, None while still running.
          finished_time = time.time() if finished else None
          seq_group.set_finished_time(finished_time)

          return cls(
              seq_group.request_id,
              prompt,
              prompt_token_ids,
              prompt_logprobs,
              outputs,
              finished,
              seq_group.metrics,
              lora_request=seq_group.lora_request,
          )

      def __repr__(self) -> str:
          return (f"RequestOutput(request_id={self.request_id}, "
                  f"prompt={self.prompt!r}, "
                  f"prompt_token_ids={self.prompt_token_ids}, "
                  f"prompt_logprobs={self.prompt_logprobs}, "
                  f"outputs={self.outputs}, "
                  f"finished={self.finished}, "
                  f"metrics={self.metrics}, "
                  f"lora_request={self.lora_request})")