outputs.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from typing import List, Optional
  2. from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
  3. SequenceGroup, SequenceStatus)
  4. class CompletionOutput:
  5. """The output data of one completion output of a request.
  6. Args:
  7. index: The index of the output in the request.
  8. text: The generated output text.
  9. token_ids: The token IDs of the generated output text.
  10. cumulative_logprob: The cumulative log probability of the generated
  11. output text.
  12. logprobs: The log probabilities of the top probability words at each
  13. position if the logprobs are requested.
  14. finish_reason: The reason why the sequence is finished.
  15. """
  16. def __init__(
  17. self,
  18. index: int,
  19. text: str,
  20. token_ids: List[int],
  21. cumulative_logprob: float,
  22. logprobs: Optional[SampleLogprobs],
  23. finish_reason: Optional[str] = None,
  24. ) -> None:
  25. self.index = index
  26. self.text = text
  27. self.token_ids = token_ids
  28. self.cumulative_logprob = cumulative_logprob
  29. self.logprobs = logprobs
  30. self.finish_reason = finish_reason
  31. def finished(self) -> bool:
  32. return self.finish_reason is not None
  33. def __repr__(self) -> str:
  34. return (f"CompletionOutput(index={self.index}, "
  35. f"text={self.text!r}, "
  36. f"token_ids={self.token_ids}, "
  37. f"cumulative_logprob={self.cumulative_logprob}, "
  38. f"logprobs={self.logprobs}, "
  39. f"finish_reason={self.finish_reason})")
  40. class RequestOutput:
  41. """The output data of a request to the LLM.
  42. Args:
  43. request_id: The unique ID of the request.
  44. prompt: The prompt string of the request.
  45. prompt_token_ids: The token IDs of the prompt.
  46. outputs: The output sequences of the request.
  47. finished: Whether the whole request is finished.
  48. """
  49. def __init__(
  50. self,
  51. request_id: str,
  52. prompt: str,
  53. prompt_token_ids: List[int],
  54. prompt_logprobs: Optional[PromptLogprobs],
  55. outputs: List[CompletionOutput],
  56. finished: bool,
  57. ) -> None:
  58. self.request_id = request_id
  59. self.prompt = prompt
  60. self.prompt_token_ids = prompt_token_ids
  61. self.prompt_logprobs = prompt_logprobs
  62. self.outputs = outputs
  63. self.finished = finished
  64. @classmethod
  65. def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
  66. # Get the top-n sequences.
  67. n = seq_group.sampling_params.n
  68. seqs = seq_group.get_seqs()
  69. if seq_group.sampling_params.use_beam_search:
  70. sorting_key = lambda seq: seq.get_beam_search_score(
  71. seq_group.sampling_params.length_penalty)
  72. else:
  73. sorting_key = lambda seq: seq.get_cumulative_logprob()
  74. sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
  75. top_n_seqs = sorted_seqs[:n]
  76. # Create the outputs.
  77. outputs: List[CompletionOutput] = []
  78. for seq in top_n_seqs:
  79. logprobs = seq.output_logprobs
  80. if seq_group.sampling_params.logprobs is None:
  81. # NOTE: We need to take care of this case because the sequence
  82. # always has the logprobs of the sampled tokens even if the
  83. # logprobs are not requested.
  84. logprobs = None
  85. finshed_reason = SequenceStatus.get_finished_reason(seq.status)
  86. output = CompletionOutput(seqs.index(seq), seq.output_text,
  87. seq.get_output_token_ids(),
  88. seq.get_cumulative_logprob(), logprobs,
  89. finshed_reason)
  90. outputs.append(output)
  91. # Every sequence in the sequence group should have the same prompt.
  92. prompt = seq_group.prompt
  93. prompt_token_ids = seq_group.prompt_token_ids
  94. prompt_logprobs = seq_group.prompt_logprobs
  95. finished = seq_group.is_finished()
  96. return cls(seq_group.request_id, prompt, prompt_token_ids,
  97. prompt_logprobs, outputs, finished)
  98. def __repr__(self) -> str:
  99. return (f"RequestOutput(request_id={self.request_id}, "
  100. f"prompt={self.prompt!r}, "
  101. f"prompt_token_ids={self.prompt_token_ids}, "
  102. f"prompt_logprobs={self.prompt_logprobs}, "
  103. f"outputs={self.outputs}, "
  104. f"finished={self.finished})")