sequence.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727
  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. from dataclasses import dataclass
  5. from typing import TYPE_CHECKING, Dict, List, Optional, Union
  6. from aphrodite.common.block import LogicalTokenBlock
  7. from aphrodite.common.sampling_params import SamplingParams
  8. from aphrodite.lora.request import LoRARequest
  9. if TYPE_CHECKING:
  10. import torch
  11. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  12. @dataclass
  13. class Logprob:
  14. """Infos for supporting OpenAI compatible logprobs and token ranks.
  15. Attributes:
  16. logprob: The logprob of chosen token
  17. rank: The vocab rank of chosen token (>=1)
  18. decoded_token: The decoded chosen token index
  19. """
  20. logprob: float
  21. rank: Optional[int] = None
  22. decoded_token: Optional[str] = None
  23. # {token_id -> logprob} per each sequence group. None if the corresponding
  24. # sequence group doesn't require prompt logprob.
  25. PromptLogprobs = List[Optional[Dict[int, Logprob]]]
  26. # {token_id -> logprob} for each sequence group.
  27. SampleLogprobs = List[Dict[int, Logprob]]
  28. class SequenceStatus(enum.Enum):
  29. """Status of a sequence."""
  30. WAITING = enum.auto()
  31. RUNNING = enum.auto()
  32. SWAPPED = enum.auto()
  33. FINISHED_STOPPED = enum.auto()
  34. FINISHED_LENGTH_CAPPED = enum.auto()
  35. FINISHED_ABORTED = enum.auto()
  36. FINISHED_IGNORED = enum.auto()
  37. @staticmethod
  38. def is_finished(status: "SequenceStatus") -> bool:
  39. return status in [
  40. SequenceStatus.FINISHED_STOPPED,
  41. SequenceStatus.FINISHED_LENGTH_CAPPED,
  42. SequenceStatus.FINISHED_ABORTED,
  43. SequenceStatus.FINISHED_IGNORED,
  44. ]
  45. @staticmethod
  46. def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
  47. if status == SequenceStatus.FINISHED_STOPPED:
  48. finish_reason = "stop"
  49. elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
  50. finish_reason = "length"
  51. elif status == SequenceStatus.FINISHED_ABORTED:
  52. finish_reason = "abort"
  53. elif status == SequenceStatus.FINISHED_IGNORED:
  54. # The ignored sequences are the sequences whose prompt lengths
  55. # are longer than the model's length cap. Therefore, the stop
  56. # reason should also be "length" as in OpenAI API.
  57. finish_reason = "length"
  58. else:
  59. finish_reason = None
  60. return finish_reason
  61. class SequenceStage(enum.Enum):
  62. PREFILL = enum.auto()
  63. DECODE = enum.auto()
  64. @dataclass
  65. class RequestMetrics:
  66. """Metrics associated with a request.
  67. Attributes:
  68. arrival_time: The time when the request arrived.
  69. first_scheduled_time: The time when the request was first scheduled.
  70. first_token_time: The time when the first token was generated.
  71. time_in_queue: The time the request spent in the queue.
  72. finished_time: The time when the request was finished.
  73. """
  74. arrival_time: float
  75. last_token_time: float
  76. first_scheduled_time: Optional[float]
  77. first_token_time: Optional[float]
  78. time_in_queue: Optional[float]
  79. finished_time: Optional[float] = None
  80. class SequenceData:
  81. """Data associated with a sequence.
  82. Args:
  83. prompt_token_ids: The token IDs of the prompt.
  84. output_token_ids: The token IDs of the output. Set to an empty list if
  85. None.
  86. Attributes:
  87. prompt_token_ids: The token IDs of the prompt.
  88. output_token_ids: The token IDs of the output.
  89. cumulative_logprob: The cumulative log probability of the output.
  90. """
  91. def __init__(
  92. self,
  93. prompt_token_ids: List[int],
  94. output_token_ids: Optional[List[int]] = None,
  95. ) -> None:
  96. if output_token_ids is None:
  97. output_token_ids = []
  98. self.prompt_token_ids = prompt_token_ids
  99. self.output_token_ids = output_token_ids
  100. self.cumulative_logprob = 0.0
  101. # The number of tokens that are computed (that run against the model).
  102. self._num_computed_tokens = 0
  103. self._stage: SequenceStage = SequenceStage.PREFILL
  104. def append_token_id(self, token_id: int, logprob: float) -> None:
  105. self.output_token_ids.append(token_id)
  106. self.cumulative_logprob += logprob
  107. def get_len(self) -> int:
  108. return len(self.output_token_ids) + len(self.prompt_token_ids)
  109. def get_prompt_len(self) -> int:
  110. return len(self.prompt_token_ids)
  111. def get_output_len(self) -> int:
  112. return len(self.output_token_ids)
  113. def get_token_ids(self) -> List[int]:
  114. return self.prompt_token_ids + self.output_token_ids
  115. def get_num_computed_tokens(self) -> int:
  116. """Return the number of prefill tokens that are already computed."""
  117. return self._num_computed_tokens
  118. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  119. """Update number of tokens computed so far."""
  120. self._num_computed_tokens += num_new_computed_tokens
  121. assert self._num_computed_tokens <= self.get_len(), (
  122. self._num_computed_tokens, self.get_len())
  123. # If all tokens are computed, it means it is in decoding phase.
  124. if self.get_num_uncomputed_tokens() == 0:
  125. self._stage = SequenceStage.DECODE
  126. def reset_state_for_recompute(self) -> None:
  127. """Reset the number of computed tokens from this sequence. It is
  128. supposed to be called when a sequence needs to be started from
  129. the beginning again (e.g., sequence is preempted).
  130. """
  131. self._num_computed_tokens = 0
  132. self._stage = SequenceStage.PREFILL
  133. def get_num_uncomputed_tokens(self) -> int:
  134. """Return the number of prefill tokens that are not computed."""
  135. # we use `get_len()` which includes prompt_len + output_len instead
  136. # of prompt_len here. This is because during recompute we need to
  137. # prefill for both prompt and output.
  138. return self.get_len() - self.get_num_computed_tokens()
  139. def get_last_token_id(self) -> int:
  140. if not self.output_token_ids:
  141. return self.prompt_token_ids[-1]
  142. return self.output_token_ids[-1]
  143. def get_prompt_token_ids(self) -> List[int]:
  144. return self.prompt_token_ids
  145. def get_output_token_ids(self) -> List[int]:
  146. return self.output_token_ids
  147. @property
  148. def stage(self) -> SequenceStage:
  149. return self._stage
  150. def __repr__(self) -> str:
  151. return (f"SequenceData("
  152. f"prompt_token_ids={self.prompt_token_ids}, "
  153. f"output_token_ids={self.output_token_ids}, "
  154. f"cumulative_logprob={self.cumulative_logprob})")
  155. class Sequence:
  156. """Stores the data, status, and block information of a sequence.
  157. Args:
  158. seq_id: The ID of the sequence.
  159. prompt: The prompt of the sequence.
  160. prompt_token_ids: The token IDs of the prompt.
  161. block_size: The block size of the sequence. Should be the same as the
  162. block size used by the block manager and cache engine.
  163. lora_request: LoRA request.
  164. """
  165. def __init__(
  166. self,
  167. seq_id: int,
  168. prompt: str,
  169. prompt_token_ids: List[int],
  170. block_size: int,
  171. eos_token_id: Optional[int] = None,
  172. lora_request: Optional[LoRARequest] = None,
  173. ) -> None:
  174. self.seq_id = seq_id
  175. self.prompt = prompt
  176. self.block_size = block_size
  177. self.eos_token_id = eos_token_id
  178. self.lora_request = lora_request
  179. self.data: SequenceData = SequenceData(prompt_token_ids)
  180. self.output_logprobs: SampleLogprobs = []
  181. self.output_text = ""
  182. self.logical_token_blocks: List[LogicalTokenBlock] = []
  183. # Initialize the logical token blocks with the prompt token ids.
  184. self._append_tokens_to_blocks(prompt_token_ids)
  185. self.status = SequenceStatus.WAITING
  186. self.stop_reason: Union[int, str, None] = None
  187. # Used for incremental detokenization
  188. self.prefix_offset = 0
  189. self.read_offset = 0
  190. # Input + output tokens
  191. self.tokens: Optional[List[str]] = None
  192. self.persistent_data = {}
  193. @property
  194. def lora_int_id(self) -> int:
  195. return self.lora_request.lora_int_id if self.lora_request else 0
  196. def get_output_text_to_return(self, buffer_length: int):
  197. # We return the full output text if the sequence is finished.
  198. truncate = buffer_length and not self.is_finished()
  199. return self.output_text[:-buffer_length] if truncate else (
  200. self.output_text)
  201. def hash_of_block(self, logical_idx: int) -> int:
  202. # TODO: This can produce incorrect hash when block size > prompt size
  203. # Compute the number of tokens in the sequence
  204. # TODO: The current hashing function is O(L^2). We should optimize
  205. # this in the future.
  206. num_tokens = self.num_hashed_tokens_of_block(logical_idx)
  207. return hash(
  208. (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
  209. def num_hashed_tokens_of_block(self, logical_idx: int):
  210. return logical_idx * self.block_size + self.block_size
  211. def reset_state_for_recompute(self):
  212. """Reset the sequence states for recomputation."""
  213. self.data.reset_state_for_recompute()
  214. def _append_logical_block(self) -> None:
  215. block = LogicalTokenBlock(
  216. block_number=len(self.logical_token_blocks),
  217. block_size=self.block_size,
  218. )
  219. self.logical_token_blocks.append(block)
  220. def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
  221. cursor = 0
  222. while cursor < len(token_ids):
  223. if not self.logical_token_blocks:
  224. self._append_logical_block()
  225. last_block = self.logical_token_blocks[-1]
  226. if last_block.is_full():
  227. self._append_logical_block()
  228. last_block = self.logical_token_blocks[-1]
  229. num_empty_slots = last_block.get_num_empty_slots()
  230. last_block.append_tokens(token_ids[cursor:cursor +
  231. num_empty_slots])
  232. cursor += num_empty_slots
  233. def append_token_id(
  234. self,
  235. token_id: int,
  236. logprobs: Dict[int, Logprob],
  237. ) -> None:
  238. assert token_id in logprobs
  239. self._append_tokens_to_blocks([token_id])
  240. self.output_logprobs.append(logprobs)
  241. self.data.append_token_id(token_id, logprobs[token_id].logprob)
  242. def get_len(self) -> int:
  243. return self.data.get_len()
  244. def get_prompt_len(self) -> int:
  245. return self.data.get_prompt_len()
  246. def get_output_len(self) -> int:
  247. return self.data.get_output_len()
  248. def get_token_ids(self) -> List[int]:
  249. return self.data.get_token_ids()
  250. def get_prompt_token_ids(self) -> List[int]:
  251. return self.data.get_prompt_token_ids()
  252. def get_last_token_id(self) -> int:
  253. return self.data.get_last_token_id()
  254. def get_output_token_ids(self) -> List[int]:
  255. return self.data.output_token_ids
  256. def get_cumulative_logprob(self) -> float:
  257. return self.data.cumulative_logprob
  258. def get_beam_search_score(self,
  259. length_penalty: float = 1.0,
  260. seq_len: Optional[int] = None,
  261. eos_token_id: Optional[int] = None) -> float:
  262. """Calculate the beam search score with length penalty.
  263. Adapted from
  264. https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
  265. """
  266. if seq_len is None:
  267. seq_len = self.get_len()
  268. # NOTE: HF implementation does not count the EOS token
  269. # towards the length, we align with that here for testing.
  270. if (eos_token_id is not None
  271. and self.get_last_token_id() == eos_token_id):
  272. seq_len -= 1
  273. return self.get_cumulative_logprob() / (seq_len**length_penalty)
  274. def is_finished(self) -> bool:
  275. return SequenceStatus.is_finished(self.status)
  276. def fork(self, new_seq_id: int) -> "Sequence":
  277. new_seq = copy.deepcopy(self)
  278. new_seq.seq_id = new_seq_id
  279. return new_seq
  280. def get_num_new_tokens(self) -> int:
  281. """Get the number of new tokens to be computed.
  282. Returns:
  283. The new number of tokens to be computed. I.e., 1 for decode, or
  284. the remaining prompt size for prefill.
  285. """
  286. if self.data.stage == SequenceStage.DECODE:
  287. return 1
  288. return self.data.get_num_uncomputed_tokens()
  289. def is_prefill(self) -> bool:
  290. return self.data.stage == SequenceStage.PREFILL
  291. def __repr__(self) -> str:
  292. return (f"Sequence(seq_id={self.seq_id}, "
  293. f"status={self.status.name}, "
  294. f"num_blocks={len(self.logical_token_blocks)})")
  295. @dataclass
  296. class SequenceGroupState:
  297. """Mutable state tied to a specific sequence group"""
  298. # torch.Generator used in seeded sampling
  299. generator: Optional = None # type: ignore
  300. class MultiModalData:
  301. """Multi modal request.
  302. Args:
  303. type: The data type.
  304. data: The actual data.
  305. The required shape and semantic meaning of it depends on the vision
  306. language config of the hosted model.
  307. See `VisionLanguageConfig` in `config.py`.
  308. """
  309. class Type(enum.Enum):
  310. IMAGE = enum.auto()
  311. def __init__(self, type: Type, data: "torch.Tensor"):
  312. self.type = type
  313. self.data = data
  314. class SequenceGroup:
  315. """A group of sequences that are generated from the same prompt.
  316. Args:
  317. request_id: The ID of the request.
  318. seqs: The list of sequences.
  319. sampling_params: The sampling parameters used to generate the outputs.
  320. arrival_time: The arrival time of the request.
  321. lora_request: LoRA request.
  322. multi_modal_data: Multi modal data associated with the request.
  323. """
  324. def __init__(
  325. self,
  326. request_id: str,
  327. seqs: List[Sequence],
  328. sampling_params: SamplingParams,
  329. arrival_time: float,
  330. lora_request: Optional[LoRARequest] = None,
  331. multi_modal_data: Optional[MultiModalData] = None,
  332. ) -> None:
  333. self.request_id = request_id
  334. self.seqs_dict = {seq.seq_id: seq for seq in seqs}
  335. self.sampling_params = sampling_params
  336. self.metrics = RequestMetrics(arrival_time=arrival_time,
  337. last_token_time=arrival_time,
  338. first_scheduled_time=None,
  339. first_token_time=None,
  340. time_in_queue=None)
  341. self.lora_request = lora_request
  342. self.prompt_logprobs: Optional[PromptLogprobs] = None
  343. self.state = SequenceGroupState()
  344. self.multi_modal_data = multi_modal_data
  345. @property
  346. def prompt(self) -> str:
  347. # All sequences in the group should have the same prompt.
  348. # We use the prompt of an arbitrary sequence.
  349. return next(iter(self.seqs_dict.values())).prompt
  350. @property
  351. def prompt_token_ids(self) -> List[int]:
  352. # All sequences in the group should have the same prompt.
  353. # We use the prompt of an arbitrary sequence.
  354. return next(iter(self.seqs_dict.values())).data.prompt_token_ids
  355. @property
  356. def lora_int_id(self) -> int:
  357. return self.lora_request.lora_int_id if self.lora_request else 0
  358. def get_last_latency(self, now: float) -> float:
  359. """Gets last token latency for Request level timings."""
  360. latency = now - self.metrics.last_token_time
  361. self.metrics.last_token_time = now
  362. return latency
  363. def maybe_set_first_token_time(self, time: float) -> None:
  364. """Sets the first token time for Request level timings."""
  365. if self.metrics.first_token_time is None:
  366. self.metrics.first_token_time = time
  367. def maybe_set_first_scheduled_time(self, time: float) -> None:
  368. """Sets the first scheduled time and time in queue for Request
  369. level timings."""
  370. if self.metrics.first_scheduled_time is None:
  371. self.metrics.first_scheduled_time = time
  372. self.metrics.time_in_queue = time - self.metrics.arrival_time
  373. def set_finished_time(self, time: Optional[float]) -> None:
  374. """Sets the finished time for Request level timings."""
  375. self.metrics.finished_time = time
  376. def get_max_num_running_seqs(self) -> int:
  377. """The maximum number of sequences running in parallel in the remaining
  378. lifetime of the request."""
  379. if self.sampling_params.use_beam_search:
  380. # For beam search, maximally there will always be `best_of` beam
  381. # candidates running in the future.
  382. return self.sampling_params.best_of
  383. else:
  384. if self.sampling_params.best_of > self.num_seqs():
  385. # At prompt stage, the sequence group is not yet filled up
  386. # and only have one sequence running. However, in the
  387. # generation stage, we will have `best_of` sequences running.
  388. return self.sampling_params.best_of
  389. # At sampling stages, return the number of actual sequences
  390. # that are not finished yet.
  391. return self.num_unfinished_seqs()
  392. def get_seqs(
  393. self,
  394. status: Optional[SequenceStatus] = None,
  395. ) -> List[Sequence]:
  396. return list(self.seqs_dict.values()) if status is None else [
  397. seq for seq in self.seqs_dict.values() if seq.status == status
  398. ]
  399. def get_unfinished_seqs(self) -> List[Sequence]:
  400. return [
  401. seq for seq in self.seqs_dict.values() if not seq.is_finished()
  402. ]
  403. def get_finished_seqs(self) -> List[Sequence]:
  404. return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
  405. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  406. """Update number of tokens computed so far."""
  407. for seq in self.seqs_dict.values():
  408. if not seq.is_finished():
  409. seq.data.update_num_computed_tokens(num_new_computed_tokens)
  410. def get_num_uncomputed_tokens(self) -> int:
  411. num_uncomputed_tokens = 0
  412. for seq in self.get_seqs():
  413. if not seq.is_finished():
  414. num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
  415. return num_uncomputed_tokens
  416. def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
  417. # Optimization. We don't need to call get_seqs if we don't need to
  418. # filter by states.
  419. if status is None:
  420. return len(self.seqs_dict)
  421. return len(self.get_seqs(status))
  422. def num_unfinished_seqs(self) -> int:
  423. return len(self.get_unfinished_seqs())
  424. def num_finished_seqs(self) -> int:
  425. return len(self.get_finished_seqs())
  426. def find(self, seq_id: int) -> Sequence:
  427. if seq_id not in self.seqs_dict:
  428. raise ValueError(f"Sequence {seq_id} not found.")
  429. return self.seqs_dict[seq_id]
  430. def add(self, seq: Sequence) -> None:
  431. if seq.seq_id in self.seqs_dict:
  432. raise ValueError(f"Sequence {seq.seq_id} already exists.")
  433. self.seqs_dict[seq.seq_id] = seq
  434. def remove(self, seq_id: int) -> None:
  435. if seq_id not in self.seqs_dict:
  436. raise ValueError(f"Sequence {seq_id} not found.")
  437. del self.seqs_dict[seq_id]
  438. def is_finished(self) -> bool:
  439. return all(seq.is_finished() for seq in self.get_seqs())
  440. def is_prefill(self) -> bool:
  441. # Every sequences should be in the same stage.
  442. return self.get_seqs()[0].is_prefill()
  443. def __repr__(self) -> str:
  444. return (f"SequenceGroup(request_id={self.request_id}, "
  445. f"sampling_params={self.sampling_params}, "
  446. f"num_seqs={len(self.seqs_dict)})")
  447. class SequenceGroupMetadata:
  448. """Metadata for a sequence group. Used to create `AttentionMetadata`.
  449. Args:
  450. request_id: The ID of the request.
  451. is_prompt: Whether the request is at prompt stage.
  452. seq_data: The sequence data. (Seq id -> sequence data)
  453. sampling_params: The sampling parameters used to generate the outputs.
  454. block_tables: The block tables. (Seq id -> list of physical block
  455. numbers)
  456. do_sample: True if sampling is required. Sampling is not required when
  457. e.g., prefill is chunked, and the current iteration only computes
  458. query tokens for prefill, we don't need sampling.
  459. token_chunk_size: The number of tokens to be processed (per sequence).
  460. None if chunking is not required.
  461. state: Internal state tied to this sequence group.
  462. lora_request: LoRA request.
  463. multi_modal_data: Multi modal data for the request.
  464. persistent_data: The persistent data of the sequence group.
  465. """
  466. def __init__(
  467. self,
  468. request_id: str,
  469. is_prompt: bool,
  470. seq_data: Dict[int, SequenceData],
  471. sampling_params: SamplingParams,
  472. block_tables: Dict[int, List[int]],
  473. persistent_data: Dict[int, dict],
  474. do_sample: bool = True,
  475. token_chunk_size: Optional[int] = None,
  476. lora_request: Optional[LoRARequest] = None,
  477. computed_block_nums: Optional[List[int]] = None,
  478. state: Optional[SequenceGroupState] = None,
  479. multi_modal_data: Optional[MultiModalData] = None,
  480. ) -> None:
  481. self.request_id = request_id
  482. self.is_prompt = is_prompt
  483. self.seq_data = seq_data
  484. self.sampling_params = sampling_params
  485. self.block_tables = block_tables
  486. self.persistent_data = persistent_data
  487. self.lora_request = lora_request
  488. self.computed_block_nums = computed_block_nums
  489. self.multi_modal_data = multi_modal_data
  490. self.state = SequenceGroupState() if state is None else state
  491. self._token_chunk_size = token_chunk_size
  492. self.do_sample = do_sample
  493. if self._token_chunk_size is None:
  494. if is_prompt:
  495. self._token_chunk_size = list(seq_data.values())[0].get_len()
  496. else:
  497. self._token_chunk_size = 1
  498. @property
  499. def lora_int_id(self) -> int:
  500. return self.lora_request.lora_int_id if self.lora_request else 0
  501. @property
  502. def token_chunk_size(self) -> Optional[int]:
  503. """Return the number of tokens to be processed (chunk size)."""
  504. return self._token_chunk_size
  505. class SequenceOutput:
  506. """The model output associated with a sequence.
  507. Args:
  508. parent_seq_id: The ID of the parent sequence (for forking in beam
  509. search).
  510. output_token: The output token ID.
  511. logprobs: The logprobs of the output token.
  512. (Token id -> logP(x_i+1 | x_0, ..., x_i))
  513. persistent_data: The persistent data of the sequence.
  514. """
  515. def __init__(
  516. self,
  517. parent_seq_id: int,
  518. output_token: int,
  519. logprobs: Dict[int, Logprob],
  520. persistent_data: dict,
  521. ) -> None:
  522. self.parent_seq_id = parent_seq_id
  523. self.output_token = output_token
  524. self.logprobs = logprobs
  525. self.persistent_data = persistent_data
  526. def __repr__(self) -> str:
  527. return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
  528. f"output_token={self.output_token}, "
  529. f"logprobs={self.logprobs}, "
  530. f"persistent_data={self.persistent_data})")
  531. def __eq__(self, other: object) -> bool:
  532. if not isinstance(other, SequenceOutput):
  533. raise NotImplementedError()
  534. equal = (self.parent_seq_id == other.parent_seq_id
  535. and self.output_token == other.output_token)
  536. log_probs_equal = other.logprobs == self.logprobs
  537. return equal and log_probs_equal
  538. class SequenceGroupOutput:
  539. """The model output associated with a sequence group."""
  540. def __init__(
  541. self,
  542. samples: List[SequenceOutput],
  543. prompt_logprobs: Optional[PromptLogprobs],
  544. ) -> None:
  545. self.samples = samples
  546. # Prompt logprob for each prompt query token.
  547. self.prompt_logprobs = prompt_logprobs
  548. def __repr__(self) -> str:
  549. return (f"SequenceGroupOutput(samples={self.samples}, "
  550. f"prompt_logprobs={self.prompt_logprobs})")
  551. def __eq__(self, other: object) -> bool:
  552. if not isinstance(other, SequenceGroupOutput):
  553. raise NotImplementedError()
  554. return (self.samples == other.samples
  555. and self.prompt_logprobs == other.prompt_logprobs)
  556. @dataclass
  557. class SamplerOutput:
  558. """For each sequence group, we generate a list of SequenceOutput object,
  559. each of which contains one possible candidate for the next token.
  560. This datastructure implements methods so it can be used like a list, but
  561. also has optional fields for device tensors.
  562. """
  563. outputs: List[SequenceGroupOutput]
  564. # On-device tensor containing probabilities of each token.
  565. sampled_token_probs: Optional["torch.Tensor"] = None
  566. # On-device tensor containing the sampled token ids.
  567. sampled_token_ids: Optional["torch.Tensor"] = None
  568. # Spec decode metrics populated by workers.
  569. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  570. def __getitem__(self, idx: int):
  571. return self.outputs[idx]
  572. def __setitem__(self, idx: int, value):
  573. self.outputs[idx] = value
  574. def __len__(self):
  575. return len(self.outputs)
  576. def __eq__(self, other: object):
  577. return isinstance(other,
  578. self.__class__) and self.outputs == other.outputs
  579. def __repr__(self) -> str:
  580. """Show the shape of a tensor instead of its values to reduce noise.
  581. """
  582. sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
  583. else self.sampled_token_probs.shape)
  584. sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
  585. self.sampled_token_ids.shape)
  586. return (
  587. f"SamplerOutput(outputs={self.outputs}, "
  588. f"sampled_token_probs={sampled_token_probs_repr}, "
  589. f"sampled_token_ids={sampled_token_ids_repr}, "
  590. f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")