# outputs.py
  1. from typing import List, Optional
  2. from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
  3. SequenceGroup, SequenceStatus)
  4. from aphrodite.lora.request import LoRARequest
  5. class CompletionOutput:
  6. """The output data of one completion output of a request.
  7. Args:
  8. index: The index of the output in the request.
  9. text: The generated output text.
  10. token_ids: The token IDs of the generated output text.
  11. cumulative_logprob: The cumulative log probability of the generated
  12. output text.
  13. logprobs: The log probabilities of the top probability words at each
  14. position if the logprobs are requested.
  15. finish_reason: The reason why the sequence is finished.
  16. lora_request: The LoRA request that was used to generate the output.
  17. """
  18. def __init__(
  19. self,
  20. index: int,
  21. text: str,
  22. token_ids: List[int],
  23. cumulative_logprob: float,
  24. logprobs: Optional[SampleLogprobs],
  25. finish_reason: Optional[str] = None,
  26. lora_request: Optional[LoRARequest] = None,
  27. ) -> None:
  28. self.index = index
  29. self.text = text
  30. self.token_ids = token_ids
  31. self.cumulative_logprob = cumulative_logprob
  32. self.logprobs = logprobs
  33. self.finish_reason = finish_reason
  34. self.lora_request = lora_request
  35. def finished(self) -> bool:
  36. return self.finish_reason is not None
  37. def __repr__(self) -> str:
  38. return (f"CompletionOutput(index={self.index}, "
  39. f"text={self.text!r}, "
  40. f"token_ids={self.token_ids}, "
  41. f"cumulative_logprob={self.cumulative_logprob}, "
  42. f"logprobs={self.logprobs}, "
  43. f"finish_reason={self.finish_reason})")
  44. class RequestOutput:
  45. """The output data of a request to the LLM.
  46. Args:
  47. request_id: The unique ID of the request.
  48. prompt: The prompt string of the request.
  49. prompt_token_ids: The token IDs of the prompt.
  50. prompt_logprobs: The log probabilities to return per prompt token.
  51. outputs: The output sequences of the request.
  52. finished: Whether the whole request is finished.
  53. lora_request: The LoRA request that was used to generate the output.
  54. """
  55. def __init__(
  56. self,
  57. request_id: str,
  58. prompt: str,
  59. prompt_token_ids: List[int],
  60. prompt_logprobs: Optional[PromptLogprobs],
  61. outputs: List[CompletionOutput],
  62. finished: bool,
  63. lora_request: Optional[LoRARequest] = None,
  64. ) -> None:
  65. self.request_id = request_id
  66. self.prompt = prompt
  67. self.prompt_token_ids = prompt_token_ids
  68. self.prompt_logprobs = prompt_logprobs
  69. self.outputs = outputs
  70. self.finished = finished
  71. self.lora_request = lora_request
  72. @classmethod
  73. def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
  74. # Get the top-n sequences.
  75. n = seq_group.sampling_params.n
  76. seqs = seq_group.get_seqs()
  77. if seq_group.sampling_params.use_beam_search:
  78. sorting_key = lambda seq: seq.get_beam_search_score(
  79. seq_group.sampling_params.length_penalty)
  80. else:
  81. sorting_key = lambda seq: seq.get_cumulative_logprob()
  82. sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
  83. top_n_seqs = sorted_seqs[:n]
  84. # Create the outputs.
  85. outputs: List[CompletionOutput] = []
  86. for seq in top_n_seqs:
  87. logprobs = seq.output_logprobs
  88. if seq_group.sampling_params.logprobs is None:
  89. # NOTE: We need to take care of this case because the sequence
  90. # always has the logprobs of the sampled tokens even if the
  91. # logprobs are not requested.
  92. logprobs = None
  93. finshed_reason = SequenceStatus.get_finished_reason(seq.status)
  94. output = CompletionOutput(seqs.index(seq), seq.output_text,
  95. seq.get_output_token_ids(),
  96. seq.get_cumulative_logprob(), logprobs,
  97. finshed_reason)
  98. outputs.append(output)
  99. # Every sequence in the sequence group should have the same prompt.
  100. prompt = seq_group.prompt
  101. prompt_token_ids = seq_group.prompt_token_ids
  102. prompt_logprobs = seq_group.prompt_logprobs
  103. finished = seq_group.is_finished()
  104. return cls(seq_group.request_id,
  105. prompt,
  106. prompt_token_ids,
  107. prompt_logprobs,
  108. outputs,
  109. finished,
  110. lora_request=seq_group.lora_request)
  111. def __repr__(self) -> str:
  112. return (f"RequestOutput(request_id={self.request_id}, "
  113. f"prompt={self.prompt!r}, "
  114. f"prompt_token_ids={self.prompt_token_ids}, "
  115. f"prompt_logprobs={self.prompt_logprobs}, "
  116. f"outputs={self.outputs}, "
  117. f"finished={self.finished}, "
  118. f"lora_request={self.lora_request})")