  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. from typing import Dict, List, Optional, Union
  5. from aphrodite.common.block import LogicalTokenBlock
  6. from aphrodite.common.sampling_params import SamplingParams
  7. PromptLogprobs = List[Optional[Dict[int, float]]]
  8. SampleLogprobs = List[Dict[int, float]]


class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
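        # Ignored sequences (e.g. prompts too long to be scheduled) are
        # reported as "length" as well, mirroring the OpenAI-style
        # finish reason.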
        elif status == SequenceStatus.FINISHED_IGNORED:
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason


class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")


class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None
        self.persistent_data = {}

    def _append_logical_block(self) -> None:
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            # Fill the last block with as many tokens as it can hold, then
            # advance the cursor past the tokens just appended.
            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots
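
    # Illustrative example (numbers are made up): with block_size=16, a
    # 20-token prompt fills one block completely and places the remaining
    # 4 tokens in a second block, leaving 12 empty slots for new tokens.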

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, float],
    ) -> None:
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 0.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from
        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: The HF implementation does not count the EOS token
            # towards the length; we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
        self.prompt_logprobs: Optional[PromptLogprobs] = None

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the
        remaining lifetime of the request."""
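        # Illustrative example (values are made up): with best_of=4 and no
        # beam search, only one sequence exists during the prompt stage, so
        # this returns 4; once all four are running, it returns the number
        # still unfinished.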
        if self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if self.sampling_params.best_of > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only has one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        if status is None:
            return list(self.seqs_dict.values())
        else:
            return [
                seq for seq in self.seqs_dict.values() if seq.status == status
            ]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")


class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        persistent_data: The persistent data of the sequences. (Seq id ->
            dict)
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        persistent_data: Dict[int, dict],
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.persistent_data = persistent_data


class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(self, parent_seq_id: int, output_token: int,
                 logprobs: Dict[int, float], persistent_data: dict) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs
        self.persistent_data = persistent_data

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs}, "
                f"persistent_data={self.persistent_data})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs
                and self.persistent_data == other.persistent_data)


class SequenceGroupOutput:
    """The model outputs associated with a sequence group."""

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceGroupOutput):
            raise NotImplementedError()
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


# For each sequence group, we generate a list of SequenceOutput objects,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[SequenceGroupOutput]
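

# Illustrative shape of one decoding step's output (values are made up):
#
#     sampler_output: SamplerOutput = [
#         SequenceGroupOutput(
#             samples=[
#                 SequenceOutput(parent_seq_id=0, output_token=42,
#                                logprobs={42: -0.1}, persistent_data={}),
#             ],
#             prompt_logprobs=None,
#         ),
#         # ... one SequenceGroupOutput per scheduled sequence group.
#     ]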