sequence.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872
  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. from abc import ABC, abstractmethod
  5. from dataclasses import dataclass, field
  6. from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
  7. import torch
  8. from aphrodite.common.block import LogicalTokenBlock
  9. from aphrodite.common.inputs import LLMInputs
  10. from aphrodite.common.pooling_params import PoolingParams
  11. from aphrodite.common.sampling_params import SamplingParams
  12. from aphrodite.lora.request import LoRARequest
  13. if TYPE_CHECKING:
  14. from aphrodite.multimodal import MultiModalData
  15. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  16. @dataclass
  17. class Logprob:
  18. """Infos for supporting OpenAI compatible logprobs and token ranks.
  19. Attributes:
  20. logprob: The logprob of chosen token
  21. rank: The vocab rank of chosen token (>=1)
  22. decoded_token: The decoded chosen token index
  23. """
  24. logprob: float
  25. rank: Optional[int] = None
  26. decoded_token: Optional[str] = None
  27. # {token_id -> logprob} per each sequence group. None if the corresponding
  28. # sequence group doesn't require prompt logprob.
  29. PromptLogprobs = List[Optional[Dict[int, Logprob]]]
  30. # {token_id -> logprob} for each sequence group.
  31. SampleLogprobs = List[Dict[int, Logprob]]
  32. class SequenceStatus(enum.Enum):
  33. """Status of a sequence."""
  34. WAITING = enum.auto()
  35. RUNNING = enum.auto()
  36. SWAPPED = enum.auto()
  37. FINISHED_STOPPED = enum.auto()
  38. FINISHED_LENGTH_CAPPED = enum.auto()
  39. FINISHED_ABORTED = enum.auto()
  40. FINISHED_IGNORED = enum.auto()
  41. @staticmethod
  42. def is_finished(status: "SequenceStatus") -> bool:
  43. return status in [
  44. SequenceStatus.FINISHED_STOPPED,
  45. SequenceStatus.FINISHED_LENGTH_CAPPED,
  46. SequenceStatus.FINISHED_ABORTED,
  47. SequenceStatus.FINISHED_IGNORED,
  48. ]
  49. @staticmethod
  50. def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
  51. if status == SequenceStatus.FINISHED_STOPPED:
  52. finish_reason = "stop"
  53. elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
  54. finish_reason = "length"
  55. elif status == SequenceStatus.FINISHED_ABORTED:
  56. finish_reason = "abort"
  57. elif status == SequenceStatus.FINISHED_IGNORED:
  58. # The ignored sequences are the sequences whose prompt lengths
  59. # are longer than the model's length cap. Therefore, the stop
  60. # reason should also be "length" as in OpenAI API.
  61. finish_reason = "length"
  62. else:
  63. finish_reason = None
  64. return finish_reason
  65. class SequenceStage(enum.Enum):
  66. PREFILL = enum.auto()
  67. DECODE = enum.auto()
  68. @dataclass
  69. class RequestMetrics:
  70. """Metrics associated with a request.
  71. Attributes:
  72. arrival_time: The time when the request arrived.
  73. first_scheduled_time: The time when the request was first scheduled.
  74. first_token_time: The time when the first token was generated.
  75. time_in_queue: The time the request spent in the queue.
  76. finished_time: The time when the request was finished.
  77. """
  78. arrival_time: float
  79. last_token_time: float
  80. first_scheduled_time: Optional[float]
  81. first_token_time: Optional[float]
  82. time_in_queue: Optional[float]
  83. finished_time: Optional[float] = None
  84. class SequenceData:
  85. """Data associated with a sequence.
  86. Args:
  87. prompt_token_ids: The token IDs of the prompt.
  88. output_token_ids: The token IDs of the output. Set to an empty list if
  89. None.
  90. Attributes:
  91. prompt_token_ids: The token IDs of the prompt.
  92. output_token_ids: The token IDs of the output.
  93. cumulative_logprob: The cumulative log probability of the output.
  94. """
  95. def __init__(
  96. self,
  97. prompt_token_ids: List[int],
  98. output_token_ids: Optional[List[int]] = None,
  99. ) -> None:
  100. if output_token_ids is None:
  101. output_token_ids = []
  102. self.prompt_token_ids = prompt_token_ids
  103. self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
  104. self.output_token_ids = output_token_ids
  105. self.cumulative_logprob = 0.0
  106. # The number of tokens that are computed (that run against the model).
  107. self._num_computed_tokens = 0
  108. self._stage: SequenceStage = SequenceStage.PREFILL
  109. def append_token_id(self, token_id: int, logprob: float) -> None:
  110. self.output_token_ids.append(token_id)
  111. self.cumulative_logprob += logprob
  112. def get_len(self) -> int:
  113. return len(self.output_token_ids) + len(self.prompt_token_ids)
  114. def get_prompt_len(self) -> int:
  115. return len(self.prompt_token_ids)
  116. def get_output_len(self) -> int:
  117. return len(self.output_token_ids)
  118. def get_token_ids(self) -> List[int]:
  119. return self.prompt_token_ids + self.output_token_ids
  120. def get_prefix_token_ids(
  121. self, num_tokens: int
  122. ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
  123. """Get prefix tokens, and make the return value hashable"""
  124. prompt_length = len(self.prompt_token_ids)
  125. if num_tokens > prompt_length:
  126. return (self._prompt_token_ids_tuple,
  127. tuple(self.output_token_ids[:num_tokens - prompt_length]))
  128. else:
  129. return (self._prompt_token_ids_tuple[:num_tokens], None)
  130. def get_num_computed_tokens(self) -> int:
  131. """Return the number of prefill tokens that are already computed."""
  132. return self._num_computed_tokens
  133. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  134. """Update number of tokens computed so far."""
  135. self._num_computed_tokens += num_new_computed_tokens
  136. assert self._num_computed_tokens <= self.get_len(), (
  137. self._num_computed_tokens, self.get_len())
  138. # If all tokens are computed, it means it is in decoding phase.
  139. if self.get_num_uncomputed_tokens() == 0:
  140. self._stage = SequenceStage.DECODE
  141. def reset_state_for_recompute(self) -> None:
  142. """Reset the number of computed tokens from this sequence. It is
  143. supposed to be called when a sequence needs to be started from
  144. the beginning again (e.g., sequence is preempted).
  145. """
  146. self._num_computed_tokens = 0
  147. self._stage = SequenceStage.PREFILL
  148. def get_num_uncomputed_tokens(self) -> int:
  149. """Return the number of prefill tokens that are not computed."""
  150. # we use `get_len()` which includes prompt_len + output_len instead
  151. # of prompt_len here. This is because during recompute we need to
  152. # prefill for both prompt and output.
  153. return self.get_len() - self.get_num_computed_tokens()
  154. def get_last_token_id(self) -> int:
  155. if not self.output_token_ids:
  156. return self.prompt_token_ids[-1]
  157. return self.output_token_ids[-1]
  158. def get_prompt_token_ids(self) -> List[int]:
  159. return self.prompt_token_ids
  160. def get_output_token_ids(self) -> List[int]:
  161. return self.output_token_ids
  162. @property
  163. def stage(self) -> SequenceStage:
  164. return self._stage
  165. def __repr__(self) -> str:
  166. return (f"SequenceData("
  167. f"prompt_token_ids={self.prompt_token_ids}, "
  168. f"output_token_ids={self.output_token_ids}, "
  169. f"cumulative_logprob={self.cumulative_logprob})")
  170. class Sequence:
  171. """Stores the data, status, and block information of a sequence.
  172. Args:
  173. seq_id: The ID of the sequence.
  174. prompt: The prompt of the sequence.
  175. prompt_token_ids: The token IDs of the prompt.
  176. block_size: The block size of the sequence. Should be the same as the
  177. block size used by the block manager and cache engine.
  178. lora_request: LoRA request.
  179. """
  180. def __init__(
  181. self,
  182. seq_id: int,
  183. inputs: LLMInputs,
  184. block_size: int,
  185. eos_token_id: Optional[int] = None,
  186. lora_request: Optional[LoRARequest] = None,
  187. ) -> None:
  188. self.seq_id = seq_id
  189. self.inputs = inputs
  190. self.block_size = block_size
  191. self.eos_token_id = eos_token_id
  192. self.lora_request = lora_request
  193. self.data = SequenceData(self.prompt_token_ids)
  194. self.output_logprobs: SampleLogprobs = []
  195. self.output_text = ""
  196. self.logical_token_blocks: List[LogicalTokenBlock] = []
  197. # Initialize the logical token blocks with the prompt token ids.
  198. self._append_tokens_to_blocks(self.prompt_token_ids)
  199. self.status = SequenceStatus.WAITING
  200. self.stop_reason: Union[int, str, None] = None
  201. # Used for incremental detokenization
  202. self.prefix_offset = 0
  203. self.read_offset = 0
  204. # Input + output tokens
  205. self.tokens: Optional[List[str]] = None
  206. @property
  207. def prompt(self) -> Optional[str]:
  208. return self.inputs.get("prompt")
  209. @property
  210. def prompt_token_ids(self) -> List[int]:
  211. return self.inputs["prompt_token_ids"]
  212. @property
  213. def multi_modal_data(self) -> Optional["MultiModalData"]:
  214. return self.inputs.get("multi_modal_data")
  215. @property
  216. def lora_int_id(self) -> int:
  217. return self.lora_request.lora_int_id if self.lora_request else 0
  218. def get_output_text_to_return(self, buffer_length: int):
  219. # We return the full output text if the sequence is finished.
  220. truncate = buffer_length and not self.is_finished()
  221. return self.output_text[:-buffer_length] if truncate else (
  222. self.output_text)
  223. def hash_of_block(self, logical_idx: int) -> int:
  224. # TODO This can produce incorrect hash when block size > prompt size
  225. # Compute the number of tokens in the sequence
  226. # TODO: The current hashing function is O(L^2). We should optimize
  227. # this in the future.
  228. num_tokens = self.num_hashed_tokens_of_block(logical_idx)
  229. hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
  230. return hash((hashed_tokens, self.lora_int_id))
  231. def num_hashed_tokens_of_block(self, logical_idx: int):
  232. return logical_idx * self.block_size + self.block_size
  233. def reset_state_for_recompute(self):
  234. """Reset the sequence states for recomputation."""
  235. self.data.reset_state_for_recompute()
  236. def _append_logical_block(self) -> None:
  237. block = LogicalTokenBlock(
  238. block_number=len(self.logical_token_blocks),
  239. block_size=self.block_size,
  240. )
  241. self.logical_token_blocks.append(block)
  242. def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
  243. cursor = 0
  244. while cursor < len(token_ids):
  245. if not self.logical_token_blocks:
  246. self._append_logical_block()
  247. last_block = self.logical_token_blocks[-1]
  248. if last_block.is_full():
  249. self._append_logical_block()
  250. last_block = self.logical_token_blocks[-1]
  251. num_empty_slots = last_block.get_num_empty_slots()
  252. last_block.append_tokens(token_ids[cursor:cursor +
  253. num_empty_slots])
  254. cursor += num_empty_slots
  255. def append_token_id(
  256. self,
  257. token_id: int,
  258. logprobs: Dict[int, Logprob],
  259. ) -> None:
  260. assert token_id in logprobs
  261. self._append_tokens_to_blocks([token_id])
  262. self.output_logprobs.append(logprobs)
  263. self.data.append_token_id(token_id, logprobs[token_id].logprob)
  264. def get_len(self) -> int:
  265. return self.data.get_len()
  266. def get_prompt_len(self) -> int:
  267. return self.data.get_prompt_len()
  268. def get_output_len(self) -> int:
  269. return self.data.get_output_len()
  270. def get_token_ids(self) -> List[int]:
  271. return self.data.get_token_ids()
  272. def get_prompt_token_ids(self) -> List[int]:
  273. return self.data.get_prompt_token_ids()
  274. def get_last_token_id(self) -> int:
  275. return self.data.get_last_token_id()
  276. def get_output_token_ids(self) -> List[int]:
  277. return self.data.output_token_ids
  278. def get_cumulative_logprob(self) -> float:
  279. return self.data.cumulative_logprob
  280. def get_beam_search_score(self,
  281. length_penalty: float = 1.0,
  282. seq_len: Optional[int] = None,
  283. eos_token_id: Optional[int] = None) -> float:
  284. """Calculate the beam search score with length penalty.
  285. Adapted from
  286. https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
  287. """
  288. if seq_len is None:
  289. seq_len = self.get_len()
  290. # NOTE: HF implementation does not count the EOS token
  291. # towards the length, we align with that here for testing.
  292. if (eos_token_id is not None
  293. and self.get_last_token_id() == eos_token_id):
  294. seq_len -= 1
  295. return self.get_cumulative_logprob() / (seq_len**length_penalty)
  296. def is_finished(self) -> bool:
  297. return SequenceStatus.is_finished(self.status)
  298. def fork(self, new_seq_id: int) -> "Sequence":
  299. new_seq = copy.deepcopy(self)
  300. new_seq.seq_id = new_seq_id
  301. return new_seq
  302. def get_num_new_tokens(self) -> int:
  303. """Get the number of new tokens to be computed.
  304. Returns:
  305. The new number of tokens to be computed. I.e., 1 for decode, or
  306. the remaining prompt size for prefill.
  307. """
  308. if self.data.stage == SequenceStage.DECODE:
  309. return 1
  310. return self.data.get_num_uncomputed_tokens()
  311. def is_prefill(self) -> bool:
  312. return self.data.stage == SequenceStage.PREFILL
  313. def __repr__(self) -> str:
  314. return (f"Sequence(seq_id={self.seq_id}, "
  315. f"status={self.status.name}, "
  316. f"num_blocks={len(self.logical_token_blocks)})")
  317. @dataclass
  318. class SequenceGroupState:
  319. """Mutable state tied to a specific sequence group"""
  320. # torch.Generator used in seeded sampling
  321. generator: Optional = None # type: ignore
  322. class SequenceGroup:
  323. """A group of sequences that are generated from the same prompt.
  324. Args:
  325. request_id: The ID of the request.
  326. seqs: The list of sequences.
  327. sampling_params: The sampling parameters used to generate the outputs.
  328. arrival_time: The arrival time of the request.
  329. lora_request: LoRA request.
  330. multi_modal_data: Multi modal data associated with the request.
  331. embeddings: The embeddings vectors of the prompt of the sequence group
  332. for an embedding model.
  333. pooling_params: The pooling parameters used to generate the pooling
  334. for an embedding model.
  335. encoder_seq: Optional, the single encoder sequence. Should be None
  336. unless you are working with an encoder/decoder model.
  337. """
  338. def __init__(
  339. self,
  340. request_id: str,
  341. seqs: List[Sequence],
  342. arrival_time: float,
  343. sampling_params: Optional[SamplingParams] = None,
  344. lora_request: Optional[LoRARequest] = None,
  345. embeddings: Optional[List[float]] = None,
  346. pooling_params: Optional[PoolingParams] = None,
  347. encoder_seq: Optional[Sequence] = None,
  348. ) -> None:
  349. self.request_id = request_id
  350. self.seqs_dict = {seq.seq_id: seq for seq in seqs}
  351. self.sampling_params = sampling_params
  352. self.metrics = RequestMetrics(arrival_time=arrival_time,
  353. last_token_time=arrival_time,
  354. first_scheduled_time=None,
  355. first_token_time=None,
  356. time_in_queue=None)
  357. self.lora_request = lora_request
  358. self.prompt_logprobs: Optional[PromptLogprobs] = None
  359. self.state = SequenceGroupState()
  360. self.embeddings = embeddings
  361. self.pooling_params = pooling_params
  362. self.encoder_seq = encoder_seq
  363. @property
  364. def prompt(self) -> Optional[str]:
  365. # All sequences in the group should have the same prompt.
  366. # We use the prompt of an arbitrary sequence.
  367. return next(iter(self.seqs_dict.values())).prompt
  368. @property
  369. def prompt_token_ids(self) -> List[int]:
  370. # All sequences in the group should have the same prompt.
  371. # We use the prompt of an arbitrary sequence.
  372. return next(iter(self.seqs_dict.values())).prompt_token_ids
  373. @property
  374. def multi_modal_data(self) -> Optional["MultiModalData"]:
  375. # All sequences in the group should have the same multi-modal data.
  376. # We use the multi-modal data of an arbitrary sequence.
  377. return next(iter(self.seqs_dict.values())).multi_modal_data
  378. @property
  379. def lora_int_id(self) -> int:
  380. return self.lora_request.lora_int_id if self.lora_request else 0
  381. def get_last_latency(self, now: float) -> Optional[float]:
  382. """Sets the last token time for Request level timings."""
  383. # If still in prefill phase, raise Error.
  384. if self.is_prefill():
  385. raise ValueError(
  386. "seq_group.get_last_latency() should not be called "
  387. "if the seq_group is in prefill phase.")
  388. # Otherwise return token latency.
  389. latency = now - self.metrics.last_token_time
  390. self.metrics.last_token_time = now
  391. return latency
  392. def maybe_set_first_token_time(self, time: float) -> None:
  393. """Sets the first token time for Request level timings."""
  394. # NOTE: in a case where a sequence_group is swapped and
  395. # recomputed, the time between iterations is counted
  396. # in TPOT, rather than recalculating TTFT (since from the )
  397. # POV of the user, there is simply a long generation delay.
  398. if (self.metrics.first_token_time is None
  399. and self.get_seqs()[0].get_output_len() == 1):
  400. self.metrics.first_token_time = time
  401. def maybe_set_first_scheduled_time(self, time: float) -> None:
  402. """Sets the first scheduled time and time in queue for Request
  403. level timings."""
  404. if self.metrics.first_scheduled_time is None:
  405. self.metrics.first_scheduled_time = time
  406. self.metrics.time_in_queue = time - self.metrics.arrival_time
  407. def set_finished_time(self, time: Optional[float]) -> None:
  408. """Sets the finished time for Request level timings."""
  409. self.metrics.finished_time = time
  410. def get_max_num_running_seqs(self) -> int:
  411. """The maximum number of sequences running in parallel in the remaining
  412. lifetime of the request."""
  413. if self.sampling_params and self.sampling_params.use_beam_search:
  414. # For beam search, maximally there will always be `best_of` beam
  415. # candidates running in the future.
  416. return self.sampling_params.best_of
  417. else:
  418. if (self.sampling_params
  419. and self.sampling_params.best_of > self.num_seqs()):
  420. # At prompt stage, the sequence group is not yet filled up
  421. # and only have one sequence running. However, in the
  422. # generation stage, we will have `best_of` sequences running.
  423. return self.sampling_params.best_of
  424. # At sampling stages, return the number of actual sequences
  425. # that are not finished yet.
  426. return self.num_unfinished_seqs()
  427. def get_seqs(
  428. self,
  429. status: Optional[SequenceStatus] = None,
  430. ) -> List[Sequence]:
  431. return list(self.seqs_dict.values()) if status is None else [
  432. seq for seq in self.seqs_dict.values() if seq.status == status
  433. ]
  434. def get_unfinished_seqs(self) -> List[Sequence]:
  435. return [
  436. seq for seq in self.seqs_dict.values() if not seq.is_finished()
  437. ]
  438. def is_encoder_decoder(self) -> bool:
  439. return self.encoder_seq is not None
  440. def get_encoder_seq(self) -> Optional[Sequence]:
  441. return self.encoder_seq
  442. def get_finished_seqs(self) -> List[Sequence]:
  443. return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
  444. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  445. """Update number of tokens computed so far."""
  446. for seq in self.seqs_dict.values():
  447. if not seq.is_finished():
  448. seq.data.update_num_computed_tokens(num_new_computed_tokens)
  449. def get_num_uncomputed_tokens(self) -> int:
  450. num_uncomputed_tokens = 0
  451. for seq in self.get_seqs():
  452. if not seq.is_finished():
  453. num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
  454. return num_uncomputed_tokens
  455. def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
  456. # Optimization. We don't need to call get_seqs if we don't need to
  457. # filter by states.
  458. if status is None:
  459. return len(self.seqs_dict)
  460. return len(self.get_seqs(status))
  461. def num_unfinished_seqs(self) -> int:
  462. return len(self.get_unfinished_seqs())
  463. def num_finished_seqs(self) -> int:
  464. return len(self.get_finished_seqs())
  465. def find(self, seq_id: int) -> Sequence:
  466. if seq_id not in self.seqs_dict:
  467. raise ValueError(f"Sequence {seq_id} not found.")
  468. return self.seqs_dict[seq_id]
  469. def add(self, seq: Sequence) -> None:
  470. if seq.seq_id in self.seqs_dict:
  471. raise ValueError(f"Sequence {seq.seq_id} already exists.")
  472. self.seqs_dict[seq.seq_id] = seq
  473. def remove(self, seq_id: int) -> None:
  474. if seq_id not in self.seqs_dict:
  475. raise ValueError(f"Sequence {seq_id} not found.")
  476. del self.seqs_dict[seq_id]
  477. def is_finished(self) -> bool:
  478. return all(seq.is_finished() for seq in self.get_seqs())
  479. def is_prefill(self) -> bool:
  480. # Every sequence should be in the same stage.
  481. return self.get_seqs()[0].is_prefill()
  482. def __repr__(self) -> str:
  483. return (f"SequenceGroup(request_id={self.request_id}, "
  484. f"sampling_params={self.sampling_params}, "
  485. f"num_seqs={len(self.seqs_dict)})")
  486. class SequenceGroupMetadata:
  487. """Metadata for a sequence group. Used to create `AttentionMetadata`.
  488. Args:
  489. request_id: The ID of the request.
  490. is_prompt: Whether the request is at prompt stage.
  491. seq_data: The sequence data. (Seq id -> sequence data)
  492. sampling_params: The sampling parameters used to generate the outputs.
  493. block_tables: The block tables. (Seq id -> list of physical block
  494. numbers)
  495. do_sample: True if sampling is required. Sampling is not required when
  496. e.g., prefill is chunked, and the current iteration only computes
  497. query tokens for prefill, we don't need sampling.
  498. pooling_params: The pooling parameters used to generate the pooling.
  499. token_chunk_size: The number of tokens to be processed (per sequence).
  500. None if chunking is not required.
  501. lora_request: LoRA request.
  502. computed_block_nums: The block numbers that are already computed,
  503. used in prefix caching.
  504. state: Internal state tied to this sequence group.
  505. multi_modal_data: Multi modal data.
  506. encoder_seq_data: Optional sequence data for encoder prompt
  507. (SequenceGroup.encoder_seq). Should be None
  508. unless you are working with an encoder/decoder
  509. model.
  510. cross_block_table: Optional cross-attention block table associated
  511. with the encoder prompt
  512. (SequenceGroup.encoder_seq). Should be None
  513. unless you are working with an encoder/decoder
  514. model.
  515. """
  516. def __init__(
  517. self,
  518. request_id: str,
  519. is_prompt: bool,
  520. seq_data: Dict[int, SequenceData],
  521. sampling_params: SamplingParams,
  522. block_tables: Dict[int, List[int]],
  523. do_sample: bool = True,
  524. pooling_params: Optional[PoolingParams] = None,
  525. token_chunk_size: Optional[int] = None,
  526. lora_request: Optional[LoRARequest] = None,
  527. computed_block_nums: Optional[List[int]] = None,
  528. state: Optional[SequenceGroupState] = None,
  529. multi_modal_data: Optional["MultiModalData"] = None,
  530. encoder_seq_data: Optional[SequenceData] = None,
  531. cross_block_table: Optional[List[int]] = None,
  532. ) -> None:
  533. self.request_id = request_id
  534. self.is_prompt = is_prompt
  535. self.seq_data = seq_data
  536. self.sampling_params = sampling_params
  537. self.block_tables = block_tables
  538. self.pooling_params = pooling_params
  539. self.lora_request = lora_request
  540. self.computed_block_nums = computed_block_nums
  541. self.multi_modal_data = multi_modal_data
  542. self.state = SequenceGroupState() if state is None else state
  543. self.encoder_seq_data = encoder_seq_data
  544. self.cross_block_table = cross_block_table
  545. self._token_chunk_size = token_chunk_size
  546. self.do_sample = do_sample
  547. # The number of speculative tokens adopted in this request.
  548. # None means specuative decoding is not used.
  549. # Zero means speculative decoding is disabled for some reasons.
  550. # TODO: We should maintain this states out of the sequence group.
  551. self.num_speculative_tokens = None
  552. if self._token_chunk_size is None:
  553. if is_prompt:
  554. self._token_chunk_size = list(seq_data.values())[0].get_len()
  555. else:
  556. self._token_chunk_size = 1
  557. @property
  558. def lora_int_id(self) -> int:
  559. return self.lora_request.lora_int_id if self.lora_request else 0
  560. @property
  561. def token_chunk_size(self) -> int:
  562. """Return the number of tokens to be processed (chunk size)."""
  563. assert self._token_chunk_size is not None
  564. return self._token_chunk_size
  565. class SequenceOutput:
  566. """The model output associated with a sequence.
  567. Args:
  568. parent_seq_id: The ID of the parent sequence (for forking in beam
  569. search).
  570. output_token: The output token ID.
  571. logprobs: The logprobs of the output token.
  572. (Token id -> logP(x_i+1 | x_0, ..., x_i))
  573. """
  574. def __init__(
  575. self,
  576. parent_seq_id: int,
  577. output_token: int,
  578. logprobs: Dict[int, Logprob],
  579. ) -> None:
  580. self.parent_seq_id = parent_seq_id
  581. self.output_token = output_token
  582. self.logprobs = logprobs
  583. def __repr__(self) -> str:
  584. return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
  585. f"output_token={self.output_token}, "
  586. f"logprobs={self.logprobs})")
  587. def __eq__(self, other: object) -> bool:
  588. if not isinstance(other, SequenceOutput):
  589. raise NotImplementedError()
  590. equal = (self.parent_seq_id == other.parent_seq_id
  591. and self.output_token == other.output_token)
  592. log_probs_equal = other.logprobs == self.logprobs
  593. return equal and log_probs_equal
  594. class SequenceGroupOutput(ABC):
  595. """The base class for model outputs associated with a sequence group."""
  596. @abstractmethod
  597. def __repr__(self) -> str:
  598. pass
  599. @abstractmethod
  600. def __eq__(self, other: object) -> bool:
  601. pass
  602. class CompletionSequenceGroupOutput(SequenceGroupOutput):
  603. """The model output associated with a completion sequence group."""
  604. def __init__(
  605. self,
  606. samples: List[SequenceOutput],
  607. prompt_logprobs: Optional[PromptLogprobs],
  608. ) -> None:
  609. self.samples = samples
  610. # Prompt logprob for each prompt query token.
  611. self.prompt_logprobs = prompt_logprobs
  612. def __repr__(self) -> str:
  613. return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
  614. f"prompt_logprobs={self.prompt_logprobs})")
  615. def __eq__(self, other: object) -> bool:
  616. if not isinstance(other, CompletionSequenceGroupOutput):
  617. raise NotImplementedError()
  618. return (self.samples == other.samples
  619. and self.prompt_logprobs == other.prompt_logprobs)
  620. class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
  621. """The model output associated with an embedding sequence group."""
  622. def __init__(
  623. self,
  624. embeddings: List[float],
  625. ) -> None:
  626. self.embeddings = embeddings
  627. def __repr__(self) -> str:
  628. return (f"EmbeddingSequenceGroupOutput("
  629. f"embeddings_shape={len(self.embeddings)})")
  630. def __eq__(self, other: object) -> bool:
  631. if not isinstance(other, EmbeddingSequenceGroupOutput):
  632. raise NotImplementedError()
  633. return self.embeddings == other.embeddings
  634. @dataclass
  635. class SamplerOutput:
  636. """For each sequence group, we generate a list of SequenceOutput object,
  637. each of which contains one possible candidate for the next token.
  638. This data structure implements methods, so it can be used like a list, but
  639. also has optional fields for device tensors.
  640. """
  641. outputs: List[CompletionSequenceGroupOutput]
  642. # On-device tensor containing probabilities of each token.
  643. sampled_token_probs: Optional[torch.Tensor] = None
  644. # On-device tensor containing the logprobs of each token.
  645. logprobs: Optional[torch.Tensor] = None
  646. # On-device tensor containing the sampled token ids.
  647. sampled_token_ids: Optional[torch.Tensor] = None
  648. # Spec decode metrics populated by workers.
  649. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  650. def __getitem__(self, idx: int):
  651. return self.outputs[idx]
  652. def __setitem__(self, idx: int, value):
  653. self.outputs[idx] = value
  654. def __len__(self):
  655. return len(self.outputs)
  656. def __eq__(self, other: object):
  657. return isinstance(other,
  658. self.__class__) and self.outputs == other.outputs
  659. def __repr__(self) -> str:
  660. """Show the shape of a tensor instead of its values to reduce noise.
  661. """
  662. sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
  663. else self.sampled_token_probs.shape)
  664. sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
  665. self.sampled_token_ids.shape)
  666. return (
  667. f"SamplerOutput(outputs={self.outputs}, "
  668. f"sampled_token_probs={sampled_token_probs_repr}, "
  669. f"sampled_token_ids={sampled_token_ids_repr}, "
  670. f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
  671. @dataclass
  672. class PoolerOutput:
  673. """The output from a pooling operation in the embedding model."""
  674. outputs: List[EmbeddingSequenceGroupOutput]
  675. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  676. def __getitem__(self, idx: int):
  677. return self.outputs[idx]
  678. def __setitem__(self, idx: int, value):
  679. self.outputs[idx] = value
  680. def __len__(self):
  681. return len(self.outputs)
  682. def __eq__(self, other: object):
  683. return isinstance(other,
  684. self.__class__) and self.outputs == other.outputs
  685. @dataclass
  686. class ExecuteModelRequest:
  687. """The model execution request."""
  688. # The sequence group metadata list.
  689. seq_group_metadata_list: List[SequenceGroupMetadata]
  690. # Blocks to swap in. List of CPU -> GPU block number.
  691. blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
  692. # Blocks to swap out. List of GPU -> CPU block number.
  693. blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
  694. # Blocks to copy. Source to dest block.
  695. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
  696. # The number of slots for lookahead decoding.
  697. num_lookahead_slots: int = 0
  698. # The number of requests in the running queue.
  699. running_queue_size: int = 0
  700. def clone(
  701. self, seq_group_metadata_list: List[SequenceGroupMetadata]
  702. ) -> "ExecuteModelRequest":
  703. """Clone the request with a new sequence group metadata list."""
  704. return ExecuteModelRequest(
  705. seq_group_metadata_list=seq_group_metadata_list,
  706. blocks_to_swap_in=self.blocks_to_swap_in.copy(),
  707. blocks_to_swap_out=self.blocks_to_swap_out.copy(),
  708. blocks_to_copy=self.blocks_to_copy.copy(),
  709. num_lookahead_slots=self.num_lookahead_slots,
  710. running_queue_size=self.running_queue_size,
  711. )