  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. from dataclasses import dataclass
  5. from typing import Dict, List, Optional, Union, TYPE_CHECKING
  6. from aphrodite.common.block import LogicalTokenBlock
  7. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  8. from aphrodite.lora.request import LoRARequest
  9. if TYPE_CHECKING:
  10. import torch
  11. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics


@dataclass
class Logprob:
    """Info for supporting OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1).
        decoded_token: The decoded chosen token.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


PromptLogprobs = List[Optional[Dict[int, Logprob]]]
SampleLogprobs = List[Dict[int, Logprob]]
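
# NOTE: Illustrative sketch (not part of the original module) of the shapes
# these aliases describe: `PromptLogprobs` holds one optional dict per prompt
# position (positions without logprobs are None), and `SampleLogprobs` holds
# one dict per generated token, keyed by token ID.
#
#     prompt_logprobs: PromptLogprobs = [None, {42: Logprob(-0.1, rank=1)}]
#     sample_logprobs: SampleLogprobs = [
#         {7: Logprob(-0.5, rank=2, decoded_token="the")},
#     ]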


class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # The ignored sequences are the sequences whose prompt lengths
            # are longer than the model's length cap. Therefore, the stop
            # reason should also be "length" as in OpenAI API.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
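
# NOTE: Illustrative usage (not part of the original module): mapping a status
# to the OpenAI-style finish reason.
#
#     SequenceStatus.is_finished(SequenceStatus.RUNNING)                   # False
#     SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_STOPPED)  # "stop"
#     SequenceStatus.get_finished_reason(SequenceStatus.FINISHED_IGNORED)  # "length"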


class SequenceStage(enum.Enum):
    PREFILL = enum.auto()
    DECODE = enum.auto()


@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Args:
        arrival_time: The time when the request arrived.
        last_token_time: The time when the most recent token was generated.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL
        # The ID used to identify sequences within the same group.
        self.inner_id = 0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update the number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, the sequence is in the decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., the sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not yet computed."""
        # We use `get_len()`, which includes prompt_len + output_len, instead
        # of prompt_len here because during recompute we need to prefill both
        # the prompt and the output tokens.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")


class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token ID, if any.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization.
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens.
        self.tokens: Optional[List[str]] = None
        self.persistent_data = {}
        self.seq_group: Optional["SequenceGroup"] = None
        # The ID used to identify sequences within the same group.
        self.inner_id = 0

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int):
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        # Compute the number of tokens in the sequence.
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(
            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return logical_idx * self.block_size + self.block_size
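
    # NOTE: Illustrative sketch (not part of the original module): with
    # block_size=16, block 0 covers the first 16 token IDs and block 1 the
    # first 32, so sequences sharing a prompt prefix produce identical block
    # hashes (the LoRA ID is part of the hash, keeping adapters separate).
    #
    #     seq.num_hashed_tokens_of_block(0)   # 16
    #     seq.num_hashed_tokens_of_block(1)   # 32
    #     seq.hash_of_block(0)   # hash((tuple of first 16 token IDs, lora_int_id))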

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def _append_logical_block(self) -> None:
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots
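
    # NOTE: Illustrative sketch (not part of the original module): with
    # block_size=4, a 6-token prompt fills the first logical block completely
    # and leaves two empty slots in the second; later tokens keep filling the
    # last block before a new one is allocated.
    #
    #     seq = Sequence(seq_id=0, prompt="...",
    #                    prompt_token_ids=[1, 2, 3, 4, 5, 6], block_size=4)
    #     len(seq.logical_token_blocks)                        # 2
    #     seq.logical_token_blocks[-1].get_num_empty_slots()   # 2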

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        assert token_id in logprobs
        seq_group = self.seq_group

        def calc(seq_group):
            num_token = 0
            for block in seq_group.find(
                    seq_group.root_seq_id).logical_token_blocks:
                num_token += block.num_tokens
            return num_token

        # When multiple sequences are generated per prompt with RANDOM or
        # RANDOM_SEED sampling, non-root sequences append the token to the
        # root sequence's logical blocks instead of their own.
        if self.seq_id != seq_group.root_seq_id and \
                seq_group.sampling_params.sampling_type in (
                    SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            buffer_seq = seq_group.find(self.seq_group.root_seq_id)
        else:
            buffer_seq = self
        buffer_seq._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(
        self,
        length_penalty: float = 1.0,
        seq_len: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from
        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: The HF implementation does not count the EOS token
            # towards the length; we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)
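
    # NOTE: Illustrative worked example (not part of the original module):
    # with a cumulative logprob of -6.0 over 12 tokens, length_penalty=1.0
    # gives -6.0 / 12**1.0 = -0.5, while length_penalty=2.0 gives
    # -6.0 / 12**2.0 ≈ -0.042. Since cumulative logprobs are negative, larger
    # penalties raise the score of longer sequences relative to shorter ones.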

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        seq_group = self.seq_group
        self.seq_group = None
        new_seq = copy.deepcopy(self)
        self.seq_group = seq_group
        new_seq.seq_group = seq_group
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The number of new tokens to be computed: 1 for decode, or the
            number of remaining uncomputed (prompt) tokens for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""
    # torch.Generator used in seeded sampling.
    generator: Optional["torch.Generator"] = None


class MultiModalData:
    """Multi modal request.

    Args:
        type: The data type.
        data: The actual data.
            The required shape and semantic meaning of it depends on the
            vision language config of the hosted model.
            See `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.type = type
        self.data = data


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data for the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(
            arrival_time=arrival_time,
            last_token_time=arrival_time,
            first_scheduled_time=None,
            first_token_time=None,
            time_in_queue=None,
        )
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data
        # A sequence group should have only one sequence when initialized.
        seqs[0].seq_group = self
        self.root_seq_id = seqs[0].seq_id

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> float:
        """Gets the last token latency for request-level timings."""
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for request-level timings."""
        if self.metrics.first_token_time is None:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for request-level
        timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for request-level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the
        remaining lifetime of the request."""
        if self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if self.sampling_params.best_of > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only has one sequence running. However, in the
                # generation stage, we will have `best_of` sequences
                # running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()
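
    # NOTE: Illustrative worked example (not part of the original module):
    # with best_of=4 and no beam search, a group that still holds only its
    # single prompt-stage sequence reports 4 (it will fan out to 4 sequences),
    # whereas a group in which 3 of the 4 sequences have finished reports 1.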

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        return (list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ])

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update the number of tokens computed so far."""
        for seq in self.seqs_dict.values():
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        num_uncomputed_tokens = 0
        for seq in self.get_seqs():
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        seq.seq_group = self
        seq.inner_id = len(self.seqs_dict)
        seq.data.inner_id = seq.inner_id
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def get_root(self) -> Sequence:
        return self.find(self.root_seq_id)

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")


class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required.
        state: Internal state tied to this sequence group.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data for the request.
        persistent_data: The persistent data of the sequence group.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        persistent_data: Dict[int, dict],
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
        root_seq_id: Optional[int] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.persistent_data = persistent_data
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.state = SequenceGroupState() if state is None else state
        self.multi_modal_data = multi_modal_data
        self._token_chunk_size = token_chunk_size
        self.root_seq_id = root_seq_id

        if self._token_chunk_size is None:
            if is_prompt:
                self._token_chunk_size = list(seq_data.values())[0].get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        return self._token_chunk_size
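
# NOTE: Illustrative sketch (not part of the original module): when
# `token_chunk_size` is omitted, prompt-stage metadata defaults it to the full
# sequence length and decode-stage metadata to a single token.
# (`sampling_params` is assumed to be constructed elsewhere.)
#
#     meta = SequenceGroupMetadata(request_id="req-0", is_prompt=True,
#                                  seq_data={0: SequenceData([1, 2, 3])},
#                                  sampling_params=sampling_params,
#                                  block_tables={0: [0]},
#                                  persistent_data={0: {}})
#     meta.token_chunk_size    # 3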


class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        persistent_data: The persistent data of the sequence.
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
        persistent_data: dict,
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs
        self.persistent_data = persistent_data

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs}, "
                f"persistent_data={self.persistent_data})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal


class SequenceGroupOutput:
    """The model output associated with a sequence group."""

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceGroupOutput):
            raise NotImplementedError()
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods so it can be used like a list, but
    also has optional fields for device tensors.
    """
    outputs: List[SequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return (isinstance(other, self.__class__)
                and self.outputs == other.outputs)

    def __repr__(self) -> str:
        """Show the shape of a tensor instead of its values to reduce noise.
        """
        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
                                    else self.sampled_token_probs.shape)
        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None
                                  else self.sampled_token_ids.shape)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")