sequence.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850
  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. from abc import ABC, abstractmethod
  5. from dataclasses import dataclass, field
  6. from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
  7. from aphrodite.common.block import LogicalTokenBlock
  8. from aphrodite.common.pooling_params import PoolingParams
  9. from aphrodite.common.sampling_params import SamplingParams
  10. from aphrodite.lora.request import LoRARequest
  11. if TYPE_CHECKING:
  12. import torch
  13. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  14. @dataclass
  15. class Logprob:
  16. """Infos for supporting OpenAI compatible logprobs and token ranks.
  17. Attributes:
  18. logprob: The logprob of chosen token
  19. rank: The vocab rank of chosen token (>=1)
  20. decoded_token: The decoded chosen token index
  21. """
  22. logprob: float
  23. rank: Optional[int] = None
  24. decoded_token: Optional[str] = None
# Prompt logprobs: a list whose entries are {token_id -> Logprob} dicts or
# None. Per the original note, None is used when the corresponding sequence
# group doesn't require prompt logprobs.
# NOTE(review): presumably one entry per prompt token — confirm the indexing
# against the code that fills it.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs: one {token_id -> Logprob} dict per generated token of a
# sequence.
SampleLogprobs = List[Dict[int, Logprob]]
  30. class SequenceStatus(enum.Enum):
  31. """Status of a sequence."""
  32. WAITING = enum.auto()
  33. RUNNING = enum.auto()
  34. SWAPPED = enum.auto()
  35. FINISHED_STOPPED = enum.auto()
  36. FINISHED_LENGTH_CAPPED = enum.auto()
  37. FINISHED_ABORTED = enum.auto()
  38. FINISHED_IGNORED = enum.auto()
  39. @staticmethod
  40. def is_finished(status: "SequenceStatus") -> bool:
  41. return status in [
  42. SequenceStatus.FINISHED_STOPPED,
  43. SequenceStatus.FINISHED_LENGTH_CAPPED,
  44. SequenceStatus.FINISHED_ABORTED,
  45. SequenceStatus.FINISHED_IGNORED,
  46. ]
  47. @staticmethod
  48. def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
  49. if status == SequenceStatus.FINISHED_STOPPED:
  50. finish_reason = "stop"
  51. elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
  52. finish_reason = "length"
  53. elif status == SequenceStatus.FINISHED_ABORTED:
  54. finish_reason = "abort"
  55. elif status == SequenceStatus.FINISHED_IGNORED:
  56. # The ignored sequences are the sequences whose prompt lengths
  57. # are longer than the model's length cap. Therefore, the stop
  58. # reason should also be "length" as in OpenAI API.
  59. finish_reason = "length"
  60. else:
  61. finish_reason = None
  62. return finish_reason
  63. class SequenceStage(enum.Enum):
  64. PREFILL = enum.auto()
  65. DECODE = enum.auto()
  66. @dataclass
  67. class RequestMetrics:
  68. """Metrics associated with a request.
  69. Attributes:
  70. arrival_time: The time when the request arrived.
  71. first_scheduled_time: The time when the request was first scheduled.
  72. first_token_time: The time when the first token was generated.
  73. time_in_queue: The time the request spent in the queue.
  74. finished_time: The time when the request was finished.
  75. """
  76. arrival_time: float
  77. last_token_time: float
  78. first_scheduled_time: Optional[float]
  79. first_token_time: Optional[float]
  80. time_in_queue: Optional[float]
  81. finished_time: Optional[float] = None
  82. class SequenceData:
  83. """Data associated with a sequence.
  84. Args:
  85. prompt_token_ids: The token IDs of the prompt.
  86. output_token_ids: The token IDs of the output. Set to an empty list if
  87. None.
  88. Attributes:
  89. prompt_token_ids: The token IDs of the prompt.
  90. output_token_ids: The token IDs of the output.
  91. cumulative_logprob: The cumulative log probability of the output.
  92. """
  93. def __init__(
  94. self,
  95. prompt_token_ids: List[int],
  96. output_token_ids: Optional[List[int]] = None,
  97. ) -> None:
  98. if output_token_ids is None:
  99. output_token_ids = []
  100. self.prompt_token_ids = prompt_token_ids
  101. self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
  102. self.output_token_ids = output_token_ids
  103. self.cumulative_logprob = 0.0
  104. # The number of tokens that are computed (that run against the model).
  105. self._num_computed_tokens = 0
  106. self._stage: SequenceStage = SequenceStage.PREFILL
  107. def append_token_id(self, token_id: int, logprob: float) -> None:
  108. self.output_token_ids.append(token_id)
  109. self.cumulative_logprob += logprob
  110. def get_len(self) -> int:
  111. return len(self.output_token_ids) + len(self.prompt_token_ids)
  112. def get_prompt_len(self) -> int:
  113. return len(self.prompt_token_ids)
  114. def get_output_len(self) -> int:
  115. return len(self.output_token_ids)
  116. def get_token_ids(self) -> List[int]:
  117. return self.prompt_token_ids + self.output_token_ids
  118. def get_prefix_token_ids(
  119. self, num_tokens: int
  120. ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
  121. """Get prefix tokens, and make the return value hashable"""
  122. prompt_length = len(self.prompt_token_ids)
  123. if num_tokens > prompt_length:
  124. return (self._prompt_token_ids_tuple,
  125. tuple(self.output_token_ids[:num_tokens - prompt_length]))
  126. else:
  127. return (self._prompt_token_ids_tuple[:num_tokens], None)
  128. def get_num_computed_tokens(self) -> int:
  129. """Return the number of prefill tokens that are already computed."""
  130. return self._num_computed_tokens
  131. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  132. """Update number of tokens computed so far."""
  133. self._num_computed_tokens += num_new_computed_tokens
  134. assert self._num_computed_tokens <= self.get_len(), (
  135. self._num_computed_tokens, self.get_len())
  136. # If all tokens are computed, it means it is in decoding phase.
  137. if self.get_num_uncomputed_tokens() == 0:
  138. self._stage = SequenceStage.DECODE
  139. def reset_state_for_recompute(self) -> None:
  140. """Reset the number of computed tokens from this sequence. It is
  141. supposed to be called when a sequence needs to be started from
  142. the beginning again (e.g., sequence is preempted).
  143. """
  144. self._num_computed_tokens = 0
  145. self._stage = SequenceStage.PREFILL
  146. def get_num_uncomputed_tokens(self) -> int:
  147. """Return the number of prefill tokens that are not computed."""
  148. # we use `get_len()` which includes prompt_len + output_len instead
  149. # of prompt_len here. This is because during recompute we need to
  150. # prefill for both prompt and output.
  151. return self.get_len() - self.get_num_computed_tokens()
  152. def get_last_token_id(self) -> int:
  153. if not self.output_token_ids:
  154. return self.prompt_token_ids[-1]
  155. return self.output_token_ids[-1]
  156. def get_prompt_token_ids(self) -> List[int]:
  157. return self.prompt_token_ids
  158. def get_output_token_ids(self) -> List[int]:
  159. return self.output_token_ids
  160. @property
  161. def stage(self) -> SequenceStage:
  162. return self._stage
  163. def __repr__(self) -> str:
  164. return (f"SequenceData("
  165. f"prompt_token_ids={self.prompt_token_ids}, "
  166. f"output_token_ids={self.output_token_ids}, "
  167. f"cumulative_logprob={self.cumulative_logprob})")
  168. class Sequence:
  169. """Stores the data, status, and block information of a sequence.
  170. Args:
  171. seq_id: The ID of the sequence.
  172. prompt: The prompt of the sequence.
  173. prompt_token_ids: The token IDs of the prompt.
  174. block_size: The block size of the sequence. Should be the same as the
  175. block size used by the block manager and cache engine.
  176. lora_request: LoRA request.
  177. """
  178. def __init__(
  179. self,
  180. seq_id: int,
  181. prompt: str,
  182. prompt_token_ids: List[int],
  183. block_size: int,
  184. eos_token_id: Optional[int] = None,
  185. lora_request: Optional[LoRARequest] = None,
  186. ) -> None:
  187. self.seq_id = seq_id
  188. self.prompt = prompt
  189. self.block_size = block_size
  190. self.eos_token_id = eos_token_id
  191. self.lora_request = lora_request
  192. self.data: SequenceData = SequenceData(prompt_token_ids)
  193. self.output_logprobs: SampleLogprobs = []
  194. self.output_text = ""
  195. self.logical_token_blocks: List[LogicalTokenBlock] = []
  196. # Initialize the logical token blocks with the prompt token ids.
  197. self._append_tokens_to_blocks(prompt_token_ids)
  198. self.status = SequenceStatus.WAITING
  199. self.stop_reason: Union[int, str, None] = None
  200. # Used for incremental detokenization
  201. self.prefix_offset = 0
  202. self.read_offset = 0
  203. # Input + output tokens
  204. self.tokens: Optional[List[str]] = None
  205. @property
  206. def lora_int_id(self) -> int:
  207. return self.lora_request.lora_int_id if self.lora_request else 0
  208. def get_output_text_to_return(self, buffer_length: int):
  209. # We return the full output text if the sequence is finished.
  210. truncate = buffer_length and not self.is_finished()
  211. return self.output_text[:-buffer_length] if truncate else (
  212. self.output_text)
  213. def hash_of_block(self, logical_idx: int) -> int:
  214. # TODO This can produce incorrect hash when block size > prompt size
  215. # Compute the number of tokens in the sequence
  216. # TODO: The current hashing function is O(L^2). We should optimize
  217. # this in the future.
  218. num_tokens = self.num_hashed_tokens_of_block(logical_idx)
  219. hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
  220. return hash((hashed_tokens, self.lora_int_id))
  221. def num_hashed_tokens_of_block(self, logical_idx: int):
  222. return logical_idx * self.block_size + self.block_size
  223. def reset_state_for_recompute(self):
  224. """Reset the sequence states for recomputation."""
  225. self.data.reset_state_for_recompute()
  226. def _append_logical_block(self) -> None:
  227. block = LogicalTokenBlock(
  228. block_number=len(self.logical_token_blocks),
  229. block_size=self.block_size,
  230. )
  231. self.logical_token_blocks.append(block)
  232. def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
  233. cursor = 0
  234. while cursor < len(token_ids):
  235. if not self.logical_token_blocks:
  236. self._append_logical_block()
  237. last_block = self.logical_token_blocks[-1]
  238. if last_block.is_full():
  239. self._append_logical_block()
  240. last_block = self.logical_token_blocks[-1]
  241. num_empty_slots = last_block.get_num_empty_slots()
  242. last_block.append_tokens(token_ids[cursor:cursor +
  243. num_empty_slots])
  244. cursor += num_empty_slots
  245. def append_token_id(
  246. self,
  247. token_id: int,
  248. logprobs: Dict[int, Logprob],
  249. ) -> None:
  250. assert token_id in logprobs
  251. self._append_tokens_to_blocks([token_id])
  252. self.output_logprobs.append(logprobs)
  253. self.data.append_token_id(token_id, logprobs[token_id].logprob)
  254. def get_len(self) -> int:
  255. return self.data.get_len()
  256. def get_prompt_len(self) -> int:
  257. return self.data.get_prompt_len()
  258. def get_output_len(self) -> int:
  259. return self.data.get_output_len()
  260. def get_token_ids(self) -> List[int]:
  261. return self.data.get_token_ids()
  262. def get_prompt_token_ids(self) -> List[int]:
  263. return self.data.get_prompt_token_ids()
  264. def get_last_token_id(self) -> int:
  265. return self.data.get_last_token_id()
  266. def get_output_token_ids(self) -> List[int]:
  267. return self.data.output_token_ids
  268. def get_cumulative_logprob(self) -> float:
  269. return self.data.cumulative_logprob
  270. def get_beam_search_score(self,
  271. length_penalty: float = 1.0,
  272. seq_len: Optional[int] = None,
  273. eos_token_id: Optional[int] = None) -> float:
  274. """Calculate the beam search score with length penalty.
  275. Adapted from
  276. https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
  277. """
  278. if seq_len is None:
  279. seq_len = self.get_len()
  280. # NOTE: HF implementation does not count the EOS token
  281. # towards the length, we align with that here for testing.
  282. if (eos_token_id is not None
  283. and self.get_last_token_id() == eos_token_id):
  284. seq_len -= 1
  285. return self.get_cumulative_logprob() / (seq_len**length_penalty)
  286. def is_finished(self) -> bool:
  287. return SequenceStatus.is_finished(self.status)
  288. def fork(self, new_seq_id: int) -> "Sequence":
  289. new_seq = copy.deepcopy(self)
  290. new_seq.seq_id = new_seq_id
  291. return new_seq
  292. def get_num_new_tokens(self) -> int:
  293. """Get the number of new tokens to be computed.
  294. Returns:
  295. The new number of tokens to be computed. I.e., 1 for decode, or
  296. the remaining prompt size for prefill.
  297. """
  298. if self.data.stage == SequenceStage.DECODE:
  299. return 1
  300. return self.data.get_num_uncomputed_tokens()
  301. def is_prefill(self) -> bool:
  302. return self.data.stage == SequenceStage.PREFILL
  303. def __repr__(self) -> str:
  304. return (f"Sequence(seq_id={self.seq_id}, "
  305. f"status={self.status.name}, "
  306. f"num_blocks={len(self.logical_token_blocks)})")
  307. @dataclass
  308. class SequenceGroupState:
  309. """Mutable state tied to a specific sequence group"""
  310. # torch.Generator used in seeded sampling
  311. generator: Optional = None # type: ignore
  312. class MultiModalData:
  313. """Multi modal request.
  314. Args:
  315. type: The data type.
  316. data: The actual data.
  317. The required shape and semantic meaning of it depends on the vision
  318. language config of the hosted model.
  319. See `VisionLanguageConfig` in `config.py`.
  320. """
  321. class Type(enum.Enum):
  322. IMAGE = enum.auto()
  323. def __init__(self, type: Type, data: "torch.Tensor"):
  324. self.type = type
  325. self.data = data
  326. class SequenceGroup:
  327. """A group of sequences that are generated from the same prompt.
  328. Args:
  329. request_id: The ID of the request.
  330. seqs: The list of sequences.
  331. sampling_params: The sampling parameters used to generate the outputs.
  332. arrival_time: The arrival time of the request.
  333. lora_request: LoRA request.
  334. multi_modal_data: Multi modal data associated with the request.
  335. embeddings: The embeddings vectors of the prompt of the sequence group
  336. for an embedding model.
  337. pooling_params: The pooling parameters used to generate the pooling
  338. for an embedding model.
  339. """
  340. def __init__(
  341. self,
  342. request_id: str,
  343. seqs: List[Sequence],
  344. arrival_time: float,
  345. sampling_params: Optional[SamplingParams] = None,
  346. lora_request: Optional[LoRARequest] = None,
  347. multi_modal_data: Optional[MultiModalData] = None,
  348. embeddings: Optional[List[float]] = None,
  349. pooling_params: Optional[PoolingParams] = None,
  350. ) -> None:
  351. self.request_id = request_id
  352. self.seqs_dict = {seq.seq_id: seq for seq in seqs}
  353. self.sampling_params = sampling_params
  354. self.metrics = RequestMetrics(arrival_time=arrival_time,
  355. last_token_time=arrival_time,
  356. first_scheduled_time=None,
  357. first_token_time=None,
  358. time_in_queue=None)
  359. self.lora_request = lora_request
  360. self.prompt_logprobs: Optional[PromptLogprobs] = None
  361. self.state = SequenceGroupState()
  362. self.multi_modal_data = multi_modal_data
  363. self.embeddings = embeddings
  364. self.pooling_params = pooling_params
  365. @property
  366. def prompt(self) -> str:
  367. # All sequences in the group should have the same prompt.
  368. # We use the prompt of an arbitrary sequence.
  369. return next(iter(self.seqs_dict.values())).prompt
  370. @property
  371. def prompt_token_ids(self) -> List[int]:
  372. # All sequences in the group should have the same prompt.
  373. # We use the prompt of an arbitrary sequence.
  374. return next(iter(self.seqs_dict.values())).data.prompt_token_ids
  375. @property
  376. def lora_int_id(self) -> int:
  377. return self.lora_request.lora_int_id if self.lora_request else 0
  378. def get_last_latency(self, now: float) -> Optional[float]:
  379. """Sets the last token time for Request level timings."""
  380. # If still in prefill phase, raise Error.
  381. if self.is_prefill():
  382. raise ValueError(
  383. "seq_group.get_last_latency() should not be called "
  384. "if the seq_group is in prefill phase.")
  385. # Otherwise return token latency.
  386. latency = now - self.metrics.last_token_time
  387. self.metrics.last_token_time = now
  388. return latency
  389. def maybe_set_first_token_time(self, time: float) -> None:
  390. """Sets the first token time for Request level timings."""
  391. # NOTE: in a case where a sequence_group is swapped and
  392. # recomputed, the time between iterations is counted
  393. # in TPOT, rather than recalculating TTFT (since from the )
  394. # POV of the user, there is simply a long generation delay.
  395. if (self.metrics.first_token_time is None
  396. and self.get_seqs()[0].get_output_len() == 1):
  397. self.metrics.first_token_time = time
  398. def maybe_set_first_scheduled_time(self, time: float) -> None:
  399. """Sets the first scheduled time and time in queue for Request
  400. level timings."""
  401. if self.metrics.first_scheduled_time is None:
  402. self.metrics.first_scheduled_time = time
  403. self.metrics.time_in_queue = time - self.metrics.arrival_time
  404. def set_finished_time(self, time: Optional[float]) -> None:
  405. """Sets the finished time for Request level timings."""
  406. self.metrics.finished_time = time
  407. def get_max_num_running_seqs(self) -> int:
  408. """The maximum number of sequences running in parallel in the remaining
  409. lifetime of the request."""
  410. if self.sampling_params and self.sampling_params.use_beam_search:
  411. # For beam search, maximally there will always be `best_of` beam
  412. # candidates running in the future.
  413. return self.sampling_params.best_of
  414. else:
  415. if (self.sampling_params
  416. and self.sampling_params.best_of > self.num_seqs()):
  417. # At prompt stage, the sequence group is not yet filled up
  418. # and only have one sequence running. However, in the
  419. # generation stage, we will have `best_of` sequences running.
  420. return self.sampling_params.best_of
  421. # At sampling stages, return the number of actual sequences
  422. # that are not finished yet.
  423. return self.num_unfinished_seqs()
  424. def get_seqs(
  425. self,
  426. status: Optional[SequenceStatus] = None,
  427. ) -> List[Sequence]:
  428. return list(self.seqs_dict.values()) if status is None else [
  429. seq for seq in self.seqs_dict.values() if seq.status == status
  430. ]
  431. def get_unfinished_seqs(self) -> List[Sequence]:
  432. return [
  433. seq for seq in self.seqs_dict.values() if not seq.is_finished()
  434. ]
  435. def get_finished_seqs(self) -> List[Sequence]:
  436. return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
  437. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  438. """Update number of tokens computed so far."""
  439. for seq in self.seqs_dict.values():
  440. if not seq.is_finished():
  441. seq.data.update_num_computed_tokens(num_new_computed_tokens)
  442. def get_num_uncomputed_tokens(self) -> int:
  443. num_uncomputed_tokens = 0
  444. for seq in self.get_seqs():
  445. if not seq.is_finished():
  446. num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
  447. return num_uncomputed_tokens
  448. def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
  449. # Optimization. We don't need to call get_seqs if we don't need to
  450. # filter by states.
  451. if status is None:
  452. return len(self.seqs_dict)
  453. return len(self.get_seqs(status))
  454. def num_unfinished_seqs(self) -> int:
  455. return len(self.get_unfinished_seqs())
  456. def num_finished_seqs(self) -> int:
  457. return len(self.get_finished_seqs())
  458. def find(self, seq_id: int) -> Sequence:
  459. if seq_id not in self.seqs_dict:
  460. raise ValueError(f"Sequence {seq_id} not found.")
  461. return self.seqs_dict[seq_id]
  462. def add(self, seq: Sequence) -> None:
  463. if seq.seq_id in self.seqs_dict:
  464. raise ValueError(f"Sequence {seq.seq_id} already exists.")
  465. self.seqs_dict[seq.seq_id] = seq
  466. def remove(self, seq_id: int) -> None:
  467. if seq_id not in self.seqs_dict:
  468. raise ValueError(f"Sequence {seq_id} not found.")
  469. del self.seqs_dict[seq_id]
  470. def is_finished(self) -> bool:
  471. return all(seq.is_finished() for seq in self.get_seqs())
  472. def is_prefill(self) -> bool:
  473. # Every sequence should be in the same stage.
  474. return self.get_seqs()[0].is_prefill()
  475. def __repr__(self) -> str:
  476. return (f"SequenceGroup(request_id={self.request_id}, "
  477. f"sampling_params={self.sampling_params}, "
  478. f"num_seqs={len(self.seqs_dict)})")
  479. class SequenceGroupMetadata:
  480. """Metadata for a sequence group. Used to create `AttentionMetadata`.
  481. Args:
  482. request_id: The ID of the request.
  483. is_prompt: Whether the request is at prompt stage.
  484. seq_data: The sequence data. (Seq id -> sequence data)
  485. sampling_params: The sampling parameters used to generate the outputs.
  486. block_tables: The block tables. (Seq id -> list of physical block
  487. numbers)
  488. do_sample: True if sampling is required. Sampling is not required when
  489. e.g., prefill is chunked, and the current iteration only computes
  490. query tokens for prefill, we don't need sampling.
  491. pooling_params: The pooling parameters used to generate the pooling.
  492. token_chunk_size: The number of tokens to be processed (per sequence).
  493. None if chunking is not required.
  494. lora_request: LoRA request.
  495. computed_block_nums: The block numbers that are already computed,
  496. used in prefix caching.
  497. state: Internal state tied to this sequence group.
  498. multi_modal_data: Multi modal data.
  499. """
  500. def __init__(
  501. self,
  502. request_id: str,
  503. is_prompt: bool,
  504. seq_data: Dict[int, SequenceData],
  505. sampling_params: SamplingParams,
  506. block_tables: Dict[int, List[int]],
  507. do_sample: bool = True,
  508. pooling_params: Optional[PoolingParams] = None,
  509. token_chunk_size: Optional[int] = None,
  510. lora_request: Optional[LoRARequest] = None,
  511. computed_block_nums: Optional[List[int]] = None,
  512. state: Optional[SequenceGroupState] = None,
  513. multi_modal_data: Optional[MultiModalData] = None,
  514. ) -> None:
  515. self.request_id = request_id
  516. self.is_prompt = is_prompt
  517. self.seq_data = seq_data
  518. self.sampling_params = sampling_params
  519. self.block_tables = block_tables
  520. self.pooling_params = pooling_params
  521. self.lora_request = lora_request
  522. self.computed_block_nums = computed_block_nums
  523. self.multi_modal_data = multi_modal_data
  524. self.state = SequenceGroupState() if state is None else state
  525. self._token_chunk_size = token_chunk_size
  526. self.do_sample = do_sample
  527. # The number of speculative tokens adopted in this request.
  528. # None means specuative decoding is not used.
  529. # Zero means speculative decoding is disabled for some reasons.
  530. # TODO: We should maintain this states out of the sequence group.
  531. self.num_speculative_tokens = None
  532. if self._token_chunk_size is None:
  533. if is_prompt:
  534. self._token_chunk_size = list(seq_data.values())[0].get_len()
  535. else:
  536. self._token_chunk_size = 1
  537. @property
  538. def lora_int_id(self) -> int:
  539. return self.lora_request.lora_int_id if self.lora_request else 0
  540. @property
  541. def token_chunk_size(self) -> int:
  542. """Return the number of tokens to be processed (chunk size)."""
  543. assert self._token_chunk_size is not None
  544. return self._token_chunk_size
  545. class SequenceOutput:
  546. """The model output associated with a sequence.
  547. Args:
  548. parent_seq_id: The ID of the parent sequence (for forking in beam
  549. search).
  550. output_token: The output token ID.
  551. logprobs: The logprobs of the output token.
  552. (Token id -> logP(x_i+1 | x_0, ..., x_i))
  553. """
  554. def __init__(
  555. self,
  556. parent_seq_id: int,
  557. output_token: int,
  558. logprobs: Dict[int, Logprob],
  559. ) -> None:
  560. self.parent_seq_id = parent_seq_id
  561. self.output_token = output_token
  562. self.logprobs = logprobs
  563. def __repr__(self) -> str:
  564. return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
  565. f"output_token={self.output_token}, "
  566. f"logprobs={self.logprobs})")
  567. def __eq__(self, other: object) -> bool:
  568. if not isinstance(other, SequenceOutput):
  569. raise NotImplementedError()
  570. equal = (self.parent_seq_id == other.parent_seq_id
  571. and self.output_token == other.output_token)
  572. log_probs_equal = other.logprobs == self.logprobs
  573. return equal and log_probs_equal
  574. class SequenceGroupOutput(ABC):
  575. """The base class for model outputs associated with a sequence group."""
  576. @abstractmethod
  577. def __repr__(self) -> str:
  578. pass
  579. @abstractmethod
  580. def __eq__(self, other: object) -> bool:
  581. pass
  582. class CompletionSequenceGroupOutput(SequenceGroupOutput):
  583. """The model output associated with a completion sequence group."""
  584. def __init__(
  585. self,
  586. samples: List[SequenceOutput],
  587. prompt_logprobs: Optional[PromptLogprobs],
  588. ) -> None:
  589. self.samples = samples
  590. # Prompt logprob for each prompt query token.
  591. self.prompt_logprobs = prompt_logprobs
  592. def __repr__(self) -> str:
  593. return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
  594. f"prompt_logprobs={self.prompt_logprobs})")
  595. def __eq__(self, other: object) -> bool:
  596. if not isinstance(other, CompletionSequenceGroupOutput):
  597. raise NotImplementedError()
  598. return (self.samples == other.samples
  599. and self.prompt_logprobs == other.prompt_logprobs)
  600. class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
  601. """The model output associated with an embedding sequence group."""
  602. def __init__(
  603. self,
  604. embeddings: List[float],
  605. ) -> None:
  606. self.embeddings = embeddings
  607. def __repr__(self) -> str:
  608. return (f"EmbeddingSequenceGroupOutput("
  609. f"embeddings_shape={len(self.embeddings)})")
  610. def __eq__(self, other: object) -> bool:
  611. if not isinstance(other, EmbeddingSequenceGroupOutput):
  612. raise NotImplementedError()
  613. return self.embeddings == other.embeddings
  614. @dataclass
  615. class SamplerOutput:
  616. """For each sequence group, we generate a list of SequenceOutput object,
  617. each of which contains one possible candidate for the next token.
  618. This data structure implements methods, so it can be used like a list, but
  619. also has optional fields for device tensors.
  620. """
  621. outputs: List[CompletionSequenceGroupOutput]
  622. # On-device tensor containing probabilities of each token.
  623. sampled_token_probs: Optional["torch.Tensor"] = None
  624. # On-device tensor containing the logprobs of each token.
  625. logprobs: Optional["torch.Tensor"] = None
  626. # On-device tensor containing the sampled token ids.
  627. sampled_token_ids: Optional["torch.Tensor"] = None
  628. # Spec decode metrics populated by workers.
  629. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  630. def __getitem__(self, idx: int):
  631. return self.outputs[idx]
  632. def __setitem__(self, idx: int, value):
  633. self.outputs[idx] = value
  634. def __len__(self):
  635. return len(self.outputs)
  636. def __eq__(self, other: object):
  637. return isinstance(other,
  638. self.__class__) and self.outputs == other.outputs
  639. def __repr__(self) -> str:
  640. """Show the shape of a tensor instead of its values to reduce noise.
  641. """
  642. sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
  643. else self.sampled_token_probs.shape)
  644. sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
  645. self.sampled_token_ids.shape)
  646. return (
  647. f"SamplerOutput(outputs={self.outputs}, "
  648. f"sampled_token_probs={sampled_token_probs_repr}, "
  649. f"sampled_token_ids={sampled_token_ids_repr}, "
  650. f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
  651. @dataclass
  652. class PoolerOutput:
  653. """The output from a pooling operation in the embedding model."""
  654. outputs: List[EmbeddingSequenceGroupOutput]
  655. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  656. def __getitem__(self, idx: int):
  657. return self.outputs[idx]
  658. def __setitem__(self, idx: int, value):
  659. self.outputs[idx] = value
  660. def __len__(self):
  661. return len(self.outputs)
  662. def __eq__(self, other: object):
  663. return isinstance(other,
  664. self.__class__) and self.outputs == other.outputs
  665. @dataclass
  666. class ExecuteModelRequest:
  667. """The model execution request."""
  668. # The sequence group metadata list.
  669. seq_group_metadata_list: List[SequenceGroupMetadata]
  670. # Blocks to swap in. List of CPU -> GPU block number.
  671. blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
  672. # Blocks to swap out. List of GPU -> CPU block number.
  673. blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
  674. # Blocks to copy. Source to dest block.
  675. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
  676. # The number of slots for lookahead decoding.
  677. num_lookahead_slots: int = 0
  678. # The number of requests in the running queue.
  679. running_queue_size: int = 0
  680. def clone(
  681. self, seq_group_metadata_list: List[SequenceGroupMetadata]
  682. ) -> "ExecuteModelRequest":
  683. """Clone the request with a new sequence group metadata list."""
  684. return ExecuteModelRequest(
  685. seq_group_metadata_list=seq_group_metadata_list,
  686. blocks_to_swap_in=self.blocks_to_swap_in.copy(),
  687. blocks_to_swap_out=self.blocks_to_swap_out.copy(),
  688. blocks_to_copy=self.blocks_to_copy.copy(),
  689. num_lookahead_slots=self.num_lookahead_slots,
  690. running_queue_size=self.running_queue_size,
  691. )