# outputs.py
  1. from typing import List, Optional
  2. from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
  3. SequenceGroup, SequenceStatus)
  4. from aphrodite.lora.request import LoRARequest
  5. class CompletionOutput:
  6. """The output data of one completion output of a request.
  7. Args:
  8. index: The index of the output in the request.
  9. text: The generated output text.
  10. token_ids: The token IDs of the generated output text.
  11. cumulative_logprob: The cumulative log probability of the generated
  12. output text.
  13. logprobs: The log probabilities of the top probability words at each
  14. position if the logprobs are requested.
  15. finish_reason: The reason why the sequence is finished.
  16. lora_request: The LoRA request that was used to generate the output.
  17. """
  18. def __init__(
  19. self,
  20. index: int,
  21. text: str,
  22. token_ids: List[int],
  23. cumulative_logprob: float,
  24. logprobs: Optional[SampleLogprobs],
  25. finish_reason: Optional[str] = None,
  26. lora_request: Optional[LoRARequest] = None,
  27. ) -> None:
  28. self.index = index
  29. self.text = text
  30. self.token_ids = token_ids
  31. self.cumulative_logprob = cumulative_logprob
  32. self.logprobs = logprobs
  33. self.finish_reason = finish_reason
  34. self.lora_request = lora_request
  35. def finished(self) -> bool:
  36. return self.finish_reason is not None
  37. def __repr__(self) -> str:
  38. return (f"CompletionOutput(index={self.index}, "
  39. f"text={self.text!r}, "
  40. f"token_ids={self.token_ids}, "
  41. f"cumulative_logprob={self.cumulative_logprob}, "
  42. f"logprobs={self.logprobs}, "
  43. f"finish_reason={self.finish_reason})")
  44. class RequestOutput:
  45. """The output data of a request to the LLM.
  46. Args:
  47. request_id: The unique ID of the request.
  48. prompt: The prompt string of the request.
  49. prompt_token_ids: The token IDs of the prompt.
  50. prompt_logprobs: The log probabilities to return per prompt token.
  51. outputs: The output sequences of the request.
  52. finished: Whether the whole request is finished.
  53. lora_request: The LoRA request that was used to generate the output.
  54. """
  55. def __init__(
  56. self,
  57. request_id: str,
  58. prompt: str,
  59. prompt_token_ids: List[int],
  60. prompt_logprobs: Optional[PromptLogprobs],
  61. outputs: List[CompletionOutput],
  62. finished: bool,
  63. lora_request: Optional[LoRARequest] = None,
  64. ) -> None:
  65. self.request_id = request_id
  66. self.prompt = prompt
  67. self.prompt_token_ids = prompt_token_ids
  68. self.prompt_logprobs = prompt_logprobs
  69. self.outputs = outputs
  70. self.finished = finished
  71. self.lora_request = lora_request
  72. @classmethod
  73. def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
  74. # Get the top-n sequences.
  75. n = seq_group.sampling_params.n
  76. seqs = seq_group.get_seqs()
  77. if seq_group.sampling_params.use_beam_search:
  78. sorting_key = lambda seq: seq.get_beam_search_score(
  79. seq_group.sampling_params.length_penalty)
  80. else:
  81. # ruff: noqa: E731
  82. sorting_key = lambda seq: seq.get_cumulative_logprob()
  83. sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
  84. top_n_seqs = sorted_seqs[:n]
  85. # Create the outputs.
  86. outputs: List[CompletionOutput] = []
  87. for seq in top_n_seqs:
  88. logprobs = seq.output_logprobs
  89. if seq_group.sampling_params.logprobs is None:
  90. # NOTE: We need to take care of this case because the sequence
  91. # always has the logprobs of the sampled tokens even if the
  92. # logprobs are not requested.
  93. logprobs = None
  94. finshed_reason = SequenceStatus.get_finished_reason(seq.status)
  95. output = CompletionOutput(seqs.index(seq), seq.output_text,
  96. seq.get_output_token_ids(),
  97. seq.get_cumulative_logprob(), logprobs,
  98. finshed_reason)
  99. outputs.append(output)
  100. # Every sequence in the sequence group should have the same prompt.
  101. prompt = seq_group.prompt
  102. prompt_token_ids = seq_group.prompt_token_ids
  103. prompt_logprobs = seq_group.prompt_logprobs
  104. finished = seq_group.is_finished()
  105. return cls(seq_group.request_id,
  106. prompt,
  107. prompt_token_ids,
  108. prompt_logprobs,
  109. outputs,
  110. finished,
  111. lora_request=seq_group.lora_request)
  112. def __repr__(self) -> str:
  113. return (f"RequestOutput(request_id={self.request_id}, "
  114. f"prompt={self.prompt!r}, "
  115. f"prompt_token_ids={self.prompt_token_ids}, "
  116. f"prompt_logprobs={self.prompt_logprobs}, "
  117. f"outputs={self.outputs}, "
  118. f"finished={self.finished}, "
  119. f"lora_request={self.lora_request})")