sequence.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955
  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. import math
  5. from abc import ABC, abstractmethod
  6. from dataclasses import dataclass, field
  7. from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
  8. import torch
  9. from aphrodite.common.pooling_params import PoolingParams
  10. from aphrodite.common.sampling_params import SamplingParams
  11. from aphrodite.lora.request import LoRARequest
  12. if TYPE_CHECKING:
  13. from aphrodite.inputs import LLMInputs
  14. from aphrodite.multimodal import MultiModalDataDict
  15. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  16. @dataclass
  17. class Logprob:
  18. """Infos for supporting OpenAI compatible logprobs and token ranks.
  19. Attributes:
  20. logprob: The logprob of chosen token
  21. rank: The vocab rank of chosen token (>=1)
  22. decoded_token: The decoded chosen token index
  23. """
  24. logprob: float
  25. rank: Optional[int] = None
  26. decoded_token: Optional[str] = None
  27. # {token_id -> logprob} per each sequence group. None if the corresponding
  28. # sequence group doesn't require prompt logprob.
  29. PromptLogprobs = List[Optional[Dict[int, Logprob]]]
  30. # {token_id -> logprob} for each sequence group.
  31. SampleLogprobs = List[Dict[int, Logprob]]
  32. class SequenceStatus(enum.IntEnum):
  33. """Status of a sequence."""
  34. WAITING = 0
  35. RUNNING = 1
  36. SWAPPED = 2
  37. # Note: anything after SWAPPED (2) will be considered
  38. # as a finished status.
  39. FINISHED_STOPPED = 3
  40. FINISHED_LENGTH_CAPPED = 4
  41. FINISHED_ABORTED = 5
  42. FINISHED_IGNORED = 6
  43. @staticmethod
  44. def is_finished(status: "SequenceStatus") -> bool:
  45. return status > SequenceStatus.SWAPPED
  46. @staticmethod
  47. def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
  48. if status == SequenceStatus.FINISHED_STOPPED:
  49. finish_reason = "stop"
  50. elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
  51. finish_reason = "length"
  52. elif status == SequenceStatus.FINISHED_ABORTED:
  53. finish_reason = "abort"
  54. elif status == SequenceStatus.FINISHED_IGNORED:
  55. # The ignored sequences are the sequences whose prompt lengths
  56. # are longer than the model's length cap. Therefore, the stop
  57. # reason should also be "length" as in OpenAI API.
  58. finish_reason = "length"
  59. else:
  60. finish_reason = None
  61. return finish_reason
  62. class SequenceStage(enum.Enum):
  63. PREFILL = enum.auto()
  64. DECODE = enum.auto()
  65. @dataclass
  66. class RequestMetrics:
  67. """Metrics associated with a request.
  68. Attributes:
  69. arrival_time: The time when the request arrived.
  70. first_scheduled_time: The time when the request was first scheduled.
  71. first_token_time: The time when the first token was generated.
  72. time_in_queue: The time the request spent in the queue.
  73. finished_time: The time when the request was finished.
  74. """
  75. arrival_time: float
  76. last_token_time: float
  77. first_scheduled_time: Optional[float]
  78. first_token_time: Optional[float]
  79. time_in_queue: Optional[float]
  80. finished_time: Optional[float] = None
  81. class SequenceData:
  82. """Data associated with a sequence.
  83. Args:
  84. prompt_token_ids: The token IDs of the prompt.
  85. output_token_ids: The token IDs of the output. Set to an empty list if
  86. None.
  87. Attributes:
  88. prompt_token_ids: The token IDs of the prompt.
  89. output_token_ids: The token IDs of the output.
  90. cumulative_logprob: The cumulative log probability of the output.
  91. """
  92. def __init__(
  93. self,
  94. prompt_token_ids: List[int],
  95. output_token_ids: Optional[List[int]] = None,
  96. ) -> None:
  97. self._prompt_token_ids: List[int] = list(prompt_token_ids)
  98. self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
  99. self._output_token_ids: List[int] = (
  100. list(output_token_ids) if output_token_ids is not None else [])
  101. self.cumulative_logprob = 0.0
  102. # The number of tokens that are computed (that run against the model).
  103. self._num_computed_tokens = 0
  104. self._stage: SequenceStage = SequenceStage.PREFILL
  105. self._update_cached_all_tokens()
  106. def _update_cached_all_tokens(self):
  107. self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
  108. self._output_token_ids)
  109. @property
  110. def prompt_token_ids(self) -> Tuple[int, ...]:
  111. return self._prompt_token_ids_tuple
  112. @prompt_token_ids.setter
  113. def prompt_token_ids(self, new_prompt_token_ids) -> None:
  114. self._prompt_token_ids = list(new_prompt_token_ids)
  115. self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
  116. self._update_cached_all_tokens()
  117. @property
  118. def output_token_ids(self) -> Tuple[int, ...]:
  119. return tuple(self._output_token_ids)
  120. @output_token_ids.setter
  121. def output_token_ids(self, new_output_token_ids) -> None:
  122. self._output_token_ids = list(new_output_token_ids)
  123. self._update_cached_all_tokens()
  124. def append_token_id(self, token_id: int, logprob: float) -> None:
  125. self._output_token_ids.append(token_id)
  126. self._cached_all_token_ids.append(token_id)
  127. self.cumulative_logprob += logprob
  128. def get_len(self) -> int:
  129. return len(self._output_token_ids) + len(self._prompt_token_ids)
  130. def get_prompt_len(self) -> int:
  131. return len(self._prompt_token_ids)
  132. def get_output_len(self) -> int:
  133. return len(self._output_token_ids)
  134. def get_token_ids(self) -> List[int]:
  135. return self._cached_all_token_ids
  136. def get_prefix_token_ids(
  137. self, num_tokens: int
  138. ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
  139. """Get prefix tokens, and make the return value hashable"""
  140. prompt_length = self.get_prompt_len()
  141. if num_tokens > prompt_length:
  142. return (self._prompt_token_ids_tuple,
  143. tuple(self._output_token_ids[:num_tokens - prompt_length]))
  144. else:
  145. return (self._prompt_token_ids_tuple[:num_tokens], None)
  146. def get_num_computed_tokens(self) -> int:
  147. """Return the number of prefill tokens that are already computed."""
  148. return self._num_computed_tokens
  149. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  150. """Update number of tokens computed so far."""
  151. self._num_computed_tokens += num_new_computed_tokens
  152. assert self._num_computed_tokens <= self.get_len(), (
  153. self._num_computed_tokens, self.get_len())
  154. # If all tokens are computed, it means it is in decoding phase.
  155. if self.get_num_uncomputed_tokens() == 0:
  156. self._stage = SequenceStage.DECODE
  157. def reset_state_for_recompute(self) -> None:
  158. """Reset the number of computed tokens from this sequence. It is
  159. supposed to be called when a sequence needs to be started from
  160. the beginning again (e.g., sequence is preempted).
  161. """
  162. self._num_computed_tokens = 0
  163. self._stage = SequenceStage.PREFILL
  164. def get_num_uncomputed_tokens(self) -> int:
  165. """Return the number of prefill tokens that are not computed."""
  166. # we use `get_len()` which includes prompt_len + output_len instead
  167. # of prompt_len here. This is because during recompute we need to
  168. # prefill for both prompt and output.
  169. return self.get_len() - self.get_num_computed_tokens()
  170. def get_last_token_id(self) -> int:
  171. if not self._output_token_ids:
  172. return self._prompt_token_ids[-1]
  173. return self._output_token_ids[-1]
  174. def get_prompt_token_ids(self) -> Tuple[int, ...]:
  175. return self.prompt_token_ids
  176. def get_output_token_ids(self) -> Tuple[int, ...]:
  177. return self.output_token_ids
  178. @property
  179. def stage(self) -> SequenceStage:
  180. return self._stage
  181. def __repr__(self) -> str:
  182. return (f"SequenceData("
  183. f"prompt_token_ids={self._prompt_token_ids}, "
  184. f"output_token_ids={self._output_token_ids}, "
  185. f"cumulative_logprob={self.cumulative_logprob})")
  186. class Sequence:
  187. """Stores the data, status, and block information of a sequence.
  188. Args:
  189. seq_id: The ID of the sequence.
  190. inputs: The inputs of the sequence.
  191. block_size: The block size of the sequence. Should be the same as the
  192. block size used by the block manager and cache engine.
  193. lora_request: LoRA request.
  194. """
  195. def __init__(
  196. self,
  197. seq_id: int,
  198. inputs: "LLMInputs",
  199. block_size: int,
  200. eos_token_id: Optional[int] = None,
  201. lora_request: Optional[LoRARequest] = None,
  202. ) -> None:
  203. self.seq_id = seq_id
  204. self.inputs = inputs
  205. self.block_size = block_size
  206. self.eos_token_id = eos_token_id
  207. self.lora_request = lora_request
  208. self.data = SequenceData(self.prompt_token_ids)
  209. self.output_logprobs: SampleLogprobs = []
  210. self.output_text = ""
  211. self.status = SequenceStatus.WAITING
  212. self.stop_reason: Union[int, str, None] = None
  213. # Used for incremental detokenization
  214. self.prefix_offset = 0
  215. self.read_offset = 0
  216. # Input + output tokens
  217. self.tokens: Optional[List[str]] = None
  218. @property
  219. def n_blocks(self) -> int:
  220. return math.ceil(self.get_len() / self.block_size)
  221. @property
  222. def prompt(self) -> Optional[str]:
  223. return self.inputs.get("prompt")
  224. @property
  225. def prompt_token_ids(self) -> List[int]:
  226. return self.inputs["prompt_token_ids"]
  227. @property
  228. def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
  229. return self.inputs.get("multi_modal_data")
  230. @property
  231. def lora_int_id(self) -> int:
  232. return self.lora_request.lora_int_id if self.lora_request else 0
  233. def get_output_text_to_return(self, buffer_length: int):
  234. # We return the full output text if the sequence is finished.
  235. truncate = buffer_length and not self.is_finished()
  236. return self.output_text[:-buffer_length] if truncate else (
  237. self.output_text)
  238. def hash_of_block(self, logical_idx: int) -> int:
  239. # TODO This can produce incorrect hash when block size > prompt size
  240. # Compute the number of tokens in the sequence
  241. # TODO: The current hashing function is O(L^2). We should optimize
  242. # this in the future.
  243. num_tokens = self.num_hashed_tokens_of_block(logical_idx)
  244. hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
  245. return hash((hashed_tokens, self.lora_int_id))
  246. def num_hashed_tokens_of_block(self, logical_idx: int):
  247. return logical_idx * self.block_size + self.block_size
  248. def reset_state_for_recompute(self):
  249. """Reset the sequence states for recomputation."""
  250. self.data.reset_state_for_recompute()
  251. def append_token_id(
  252. self,
  253. token_id: int,
  254. logprobs: Dict[int, Logprob],
  255. ) -> None:
  256. assert token_id in logprobs
  257. self.output_logprobs.append(logprobs)
  258. self.data.append_token_id(token_id, logprobs[token_id].logprob)
  259. def get_len(self) -> int:
  260. return self.data.get_len()
  261. def get_prompt_len(self) -> int:
  262. return self.data.get_prompt_len()
  263. def get_output_len(self) -> int:
  264. return self.data.get_output_len()
  265. def get_token_ids(self) -> List[int]:
  266. return self.data.get_token_ids()
  267. def get_prompt_token_ids(self) -> Tuple[int, ...]:
  268. return self.data.get_prompt_token_ids()
  269. def get_last_token_id(self) -> int:
  270. return self.data.get_last_token_id()
  271. def get_output_token_ids(self) -> Tuple[int, ...]:
  272. return self.data.get_output_token_ids()
  273. def get_cumulative_logprob(self) -> float:
  274. return self.data.cumulative_logprob
  275. def get_beam_search_score(self,
  276. length_penalty: float = 1.0,
  277. seq_len: Optional[int] = None,
  278. eos_token_id: Optional[int] = None) -> float:
  279. """Calculate the beam search score with length penalty.
  280. Adapted from
  281. https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
  282. """
  283. if seq_len is None:
  284. seq_len = self.get_len()
  285. # NOTE: HF implementation does not count the EOS token
  286. # towards the length, we align with that here for testing.
  287. if (eos_token_id is not None
  288. and self.get_last_token_id() == eos_token_id):
  289. seq_len -= 1
  290. return self.get_cumulative_logprob() / (seq_len**length_penalty)
  291. def is_finished(self) -> bool:
  292. return SequenceStatus.is_finished(self.status)
  293. def fork(self, new_seq_id: int) -> "Sequence":
  294. new_seq = copy.deepcopy(self)
  295. new_seq.seq_id = new_seq_id
  296. return new_seq
  297. def get_num_new_tokens(self) -> int:
  298. """Get the number of new tokens to be computed.
  299. Returns:
  300. The new number of tokens to be computed. I.e., 1 for decode, or
  301. the remaining prompt size for prefill.
  302. """
  303. if self.data.stage == SequenceStage.DECODE:
  304. return 1
  305. return self.data.get_num_uncomputed_tokens()
  306. def is_prefill(self) -> bool:
  307. return self.data.stage == SequenceStage.PREFILL
  308. def __repr__(self) -> str:
  309. return (f"Sequence(seq_id={self.seq_id}, "
  310. f"status={self.status.name}, "
  311. f"num_blocks={self.n_blocks}, ")
  312. @dataclass
  313. class SequenceGroupState:
  314. """Mutable state tied to a specific sequence group"""
  315. # torch.Generator used in seeded sampling
  316. generator: Optional = None # type: ignore
  317. class SequenceGroup:
  318. """A group of sequences that are generated from the same prompt.
  319. Args:
  320. request_id: The ID of the request.
  321. seqs: The list of sequences.
  322. sampling_params: The sampling parameters used to generate the outputs.
  323. arrival_time: The arrival time of the request.
  324. lora_request: LoRA request.
  325. embeddings: The embeddings vectors of the prompt of the sequence group
  326. for an embedding model.
  327. pooling_params: The pooling parameters used to generate the pooling
  328. for an embedding model.
  329. encoder_seq: Optional, the single encoder sequence. Should be None
  330. unless you are working with an encoder/decoder model.
  331. trace_headers: OpenTelemetry trace headers.
  332. """
  333. def __init__(
  334. self,
  335. request_id: str,
  336. seqs: List[Sequence],
  337. arrival_time: float,
  338. sampling_params: Optional[SamplingParams] = None,
  339. lora_request: Optional[LoRARequest] = None,
  340. embeddings: Optional[List[float]] = None,
  341. pooling_params: Optional[PoolingParams] = None,
  342. encoder_seq: Optional[Sequence] = None,
  343. trace_headers: Optional[Dict[str, str]] = None,
  344. ) -> None:
  345. self.request_id = request_id
  346. self.seqs_dict = {seq.seq_id: seq for seq in seqs}
  347. self.sampling_params = sampling_params
  348. self.metrics = RequestMetrics(arrival_time=arrival_time,
  349. last_token_time=arrival_time,
  350. first_scheduled_time=None,
  351. first_token_time=None,
  352. time_in_queue=None)
  353. self.lora_request = lora_request
  354. self.prompt_logprobs: Optional[PromptLogprobs] = None
  355. self.state = SequenceGroupState()
  356. self.embeddings = embeddings
  357. self.pooling_params = pooling_params
  358. self.encoder_seq = encoder_seq
  359. self.trace_headers = trace_headers
  360. @property
  361. def prompt(self) -> Optional[str]:
  362. # All sequences in the group should have the same prompt.
  363. # We use the prompt of an arbitrary sequence.
  364. return next(iter(self.seqs_dict.values())).prompt
  365. @property
  366. def prompt_token_ids(self) -> List[int]:
  367. # All sequences in the group should have the same prompt.
  368. # We use the prompt of an arbitrary sequence.
  369. return next(iter(self.seqs_dict.values())).prompt_token_ids
  370. @property
  371. def multi_modal_data(self) -> "MultiModalDataDict":
  372. # All sequences in the group should have the same multi-modal data.
  373. # We use the multi-modal data of an arbitrary sequence.
  374. return next(iter(self.seqs_dict.values())).multi_modal_data
  375. @property
  376. def lora_int_id(self) -> int:
  377. return self.lora_request.lora_int_id if self.lora_request else 0
  378. def get_last_latency(self, now: float) -> Optional[float]:
  379. """Sets the last token time for Request level timings."""
  380. # If still in prefill phase, raise Error.
  381. if self.is_prefill():
  382. raise ValueError(
  383. "seq_group.get_last_latency() should not be called "
  384. "if the seq_group is in prefill phase.")
  385. # Otherwise return token latency.
  386. latency = now - self.metrics.last_token_time
  387. self.metrics.last_token_time = now
  388. return latency
  389. def maybe_set_first_token_time(self, time: float) -> None:
  390. """Sets the first token time for Request level timings."""
  391. # NOTE: in a case where a sequence_group is swapped and
  392. # recomputed, the time between iterations is counted
  393. # in TPOT, rather than recalculating TTFT (since from the )
  394. # POV of the user, there is simply a long generation delay.
  395. if (self.metrics.first_token_time is None
  396. and self.get_seqs()[0].get_output_len() == 1):
  397. self.metrics.first_token_time = time
  398. def maybe_set_first_scheduled_time(self, time: float) -> None:
  399. """Sets the first scheduled time and time in queue for Request
  400. level timings."""
  401. if self.metrics.first_scheduled_time is None:
  402. self.metrics.first_scheduled_time = time
  403. self.metrics.time_in_queue = time - self.metrics.arrival_time
  404. def set_finished_time(self, time: Optional[float]) -> None:
  405. """Sets the finished time for Request level timings."""
  406. self.metrics.finished_time = time
  407. def get_max_num_running_seqs(self) -> int:
  408. """The maximum number of sequences running in parallel in the remaining
  409. lifetime of the request."""
  410. if self.sampling_params and self.sampling_params.use_beam_search:
  411. # For beam search, maximally there will always be `best_of` beam
  412. # candidates running in the future.
  413. return self.sampling_params.best_of
  414. else:
  415. if (self.sampling_params
  416. and self.sampling_params.best_of > self.num_seqs()):
  417. # At prompt stage, the sequence group is not yet filled up
  418. # and only have one sequence running. However, in the
  419. # generation stage, we will have `best_of` sequences running.
  420. return self.sampling_params.best_of
  421. # At sampling stages, return the number of actual sequences
  422. # that are not finished yet.
  423. return self.num_unfinished_seqs()
  424. def get_seqs(
  425. self,
  426. status: Optional[SequenceStatus] = None,
  427. ) -> List[Sequence]:
  428. return list(self.seqs_dict.values()) if status is None else [
  429. seq for seq in self.seqs_dict.values() if seq.status == status
  430. ]
  431. def is_encoder_decoder(self) -> bool:
  432. return self.encoder_seq is not None
  433. def get_encoder_seq(self) -> Optional[Sequence]:
  434. return self.encoder_seq
  435. def get_unfinished_seqs(self) -> List[Sequence]:
  436. return [
  437. seq for seq in self.seqs_dict.values() if not seq.is_finished()
  438. ]
  439. def get_finished_seqs(self) -> List[Sequence]:
  440. return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
  441. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  442. """Update number of tokens computed so far."""
  443. for seq in self.seqs_dict.values():
  444. if not seq.is_finished():
  445. seq.data.update_num_computed_tokens(num_new_computed_tokens)
  446. def get_num_uncomputed_tokens(self) -> int:
  447. num_uncomputed_tokens = 0
  448. for seq in self.get_seqs():
  449. if not seq.is_finished():
  450. num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
  451. return num_uncomputed_tokens
  452. def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
  453. # Optimization. We don't need to call get_seqs if we don't need to
  454. # filter by states.
  455. if status is None:
  456. return len(self.seqs_dict)
  457. return len(self.get_seqs(status))
  458. def num_unfinished_seqs(self) -> int:
  459. return len(self.get_unfinished_seqs())
  460. def num_finished_seqs(self) -> int:
  461. return len(self.get_finished_seqs())
  462. def find(self, seq_id: int) -> Sequence:
  463. if seq_id not in self.seqs_dict:
  464. raise ValueError(f"Sequence {seq_id} not found.")
  465. return self.seqs_dict[seq_id]
  466. def add(self, seq: Sequence) -> None:
  467. if seq.seq_id in self.seqs_dict:
  468. raise ValueError(f"Sequence {seq.seq_id} already exists.")
  469. self.seqs_dict[seq.seq_id] = seq
  470. def remove(self, seq_id: int) -> None:
  471. if seq_id not in self.seqs_dict:
  472. raise ValueError(f"Sequence {seq_id} not found.")
  473. del self.seqs_dict[seq_id]
  474. def is_finished(self) -> bool:
  475. return all(seq.is_finished() for seq in self.get_seqs())
  476. def is_prefill(self) -> bool:
  477. # Every sequence should be in the same stage.
  478. return self.get_seqs()[0].is_prefill()
  479. def __repr__(self) -> str:
  480. return (f"SequenceGroup(request_id={self.request_id}, "
  481. f"sampling_params={self.sampling_params}, "
  482. f"num_seqs={len(self.seqs_dict)})")
  483. class SequenceGroupMetadata:
  484. """Metadata for a sequence group. Used to create `AttentionMetadata`.
  485. Args:
  486. request_id: The ID of the request.
  487. is_prompt: Whether the request is at prompt stage.
  488. seq_data: The sequence data. (Seq id -> sequence data)
  489. sampling_params: The sampling parameters used to generate the outputs.
  490. block_tables: The block tables. (Seq id -> list of physical block
  491. numbers)
  492. do_sample: True if sampling is required. Sampling is not required when
  493. e.g., prefill is chunked, and the current iteration only computes
  494. query tokens for prefill, we don't need sampling.
  495. token_chunk_size: The number of tokens to be processed (per sequence).
  496. None if chunking is not required.
  497. lora_request: LoRA request.
  498. computed_block_nums: The block numbers that are already computed,
  499. used in prefix caching.
  500. state: Internal state tied to this sequence group.
  501. multi_modal_data: Multi modal data.
  502. encoder_seq_data: Optional sequence data for encoder prompt
  503. (SequenceGroup.encoder_seq). Should be None
  504. unless you are working with an encoder/decoder
  505. model.
  506. cross_block_table: Optional cross-attention block table associated
  507. with the encoder prompt
  508. (SequenceGroup.encoder_seq). Should be None
  509. unless you are working with an encoder/decoder
  510. model.
  511. """
  512. def __init__(
  513. self,
  514. request_id: str,
  515. is_prompt: bool,
  516. seq_data: Dict[int, SequenceData],
  517. sampling_params: SamplingParams,
  518. block_tables: Dict[int, List[int]],
  519. do_sample: bool = True,
  520. pooling_params: Optional[PoolingParams] = None,
  521. token_chunk_size: Optional[int] = None,
  522. lora_request: Optional[LoRARequest] = None,
  523. computed_block_nums: Optional[List[int]] = None,
  524. state: Optional[SequenceGroupState] = None,
  525. multi_modal_data: Optional["MultiModalDataDict"] = None,
  526. encoder_seq_data: Optional[SequenceData] = None,
  527. cross_block_table: Optional[List[int]] = None,
  528. ) -> None:
  529. self.request_id = request_id
  530. self.is_prompt = is_prompt
  531. self.seq_data = seq_data
  532. self.sampling_params = sampling_params
  533. self.block_tables = block_tables
  534. self.pooling_params = pooling_params
  535. self.lora_request = lora_request
  536. self.computed_block_nums = computed_block_nums
  537. self.multi_modal_data = multi_modal_data
  538. self.state = SequenceGroupState() if state is None else state
  539. self.encoder_seq_data = encoder_seq_data
  540. self.cross_block_table = cross_block_table
  541. self._token_chunk_size = token_chunk_size
  542. self.do_sample = do_sample
  543. # The number of speculative tokens adopted in this request.
  544. # None means specuative decoding is not used.
  545. # Zero means speculative decoding is disabled for some reasons.
  546. # TODO: We should maintain this states out of the sequence group.
  547. self.num_speculative_tokens = None
  548. if self._token_chunk_size is None:
  549. if is_prompt:
  550. self._token_chunk_size = list(seq_data.values())[0].get_len()
  551. else:
  552. self._token_chunk_size = 1
  553. @property
  554. def lora_int_id(self) -> int:
  555. return self.lora_request.lora_int_id if self.lora_request else 0
  556. @property
  557. def token_chunk_size(self) -> int:
  558. """Return the number of tokens to be processed (chunk size)."""
  559. assert self._token_chunk_size is not None
  560. return self._token_chunk_size
  561. class SequenceOutput:
  562. """The model output associated with a sequence.
  563. Args:
  564. parent_seq_id: The ID of the parent sequence (for forking in beam
  565. search).
  566. output_token: The output token ID.
  567. logprobs: The logprobs of the output token.
  568. (Token id -> logP(x_i+1 | x_0, ..., x_i))
  569. """
  570. def __init__(
  571. self,
  572. parent_seq_id: int,
  573. output_token: int,
  574. logprobs: Dict[int, Logprob],
  575. ) -> None:
  576. self.parent_seq_id = parent_seq_id
  577. self.output_token = output_token
  578. self.logprobs = logprobs
  579. def __repr__(self) -> str:
  580. return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
  581. f"output_token={self.output_token}, "
  582. f"logprobs={self.logprobs})")
  583. def __eq__(self, other: object) -> bool:
  584. if not isinstance(other, SequenceOutput):
  585. raise NotImplementedError()
  586. equal = (self.parent_seq_id == other.parent_seq_id
  587. and self.output_token == other.output_token)
  588. log_probs_equal = other.logprobs == self.logprobs
  589. return equal and log_probs_equal
  590. class SequenceGroupOutput(ABC):
  591. """The base class for model outputs associated with a sequence group."""
  592. @abstractmethod
  593. def __repr__(self) -> str:
  594. pass
  595. @abstractmethod
  596. def __eq__(self, other: object) -> bool:
  597. pass
  598. class CompletionSequenceGroupOutput(SequenceGroupOutput):
  599. """The model output associated with a completion sequence group."""
  600. def __init__(
  601. self,
  602. samples: List[SequenceOutput],
  603. prompt_logprobs: Optional[PromptLogprobs],
  604. ) -> None:
  605. self.samples = samples
  606. # Prompt logprob for each prompt query token.
  607. self.prompt_logprobs = prompt_logprobs
  608. def __repr__(self) -> str:
  609. return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
  610. f"prompt_logprobs={self.prompt_logprobs})")
  611. def __eq__(self, other: object) -> bool:
  612. if not isinstance(other, CompletionSequenceGroupOutput):
  613. raise NotImplementedError()
  614. return (self.samples == other.samples
  615. and self.prompt_logprobs == other.prompt_logprobs)
  616. class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
  617. """The model output associated with an embedding sequence group."""
  618. def __init__(
  619. self,
  620. embeddings: List[float],
  621. ) -> None:
  622. self.embeddings = embeddings
  623. def __repr__(self) -> str:
  624. return (f"EmbeddingSequenceGroupOutput("
  625. f"embeddings_shape={len(self.embeddings)})")
  626. def __eq__(self, other: object) -> bool:
  627. if not isinstance(other, EmbeddingSequenceGroupOutput):
  628. raise NotImplementedError()
  629. return self.embeddings == other.embeddings
  630. @dataclass
  631. class IntermediateTensors:
  632. """For all pipeline stages except the last, we need to return the hidden
  633. states and residuals to be sent to the next stage. This data structure
  634. contains the hidden states and residuals for a request.
  635. """
  636. tensors: Dict[str, torch.Tensor]
  637. def __getitem__(self, key: Union[str, slice]):
  638. if isinstance(key, str):
  639. return self.tensors[key]
  640. elif isinstance(key, slice):
  641. return self.__class__({k: v[key] for k, v in self.tensors.items()})
  642. def __setitem__(self, key: str, value):
  643. self.tensors[key] = value
  644. def __len__(self):
  645. return len(self.tensors)
  646. def __eq__(self, other: object):
  647. return isinstance(other, self.__class__) and self
  648. def __repr__(self) -> str:
  649. return f"IntermediateTensors(tensors={self.tensors})"
  650. @dataclass
  651. class SamplerOutput:
  652. """For each sequence group, we generate a list of SequenceOutput object,
  653. each of which contains one possible candidate for the next token.
  654. This data structure implements methods, so it can be used like a list, but
  655. also has optional fields for device tensors.
  656. """
  657. outputs: List[CompletionSequenceGroupOutput]
  658. # On-device tensor containing probabilities of each token.
  659. sampled_token_probs: Optional[torch.Tensor] = None
  660. # On-device tensor containing the logprobs of each token.
  661. logprobs: Optional["torch.Tensor"] = None
  662. # On-device tensor containing the sampled token ids.
  663. sampled_token_ids: Optional[torch.Tensor] = None
  664. # Spec decode metrics populated by workers.
  665. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  666. # Optional last hidden states from the model.
  667. hidden_states: Optional[torch.Tensor] = None
  668. def __getitem__(self, idx: int):
  669. return self.outputs[idx]
  670. def __setitem__(self, idx: int, value):
  671. self.outputs[idx] = value
  672. def __len__(self):
  673. return len(self.outputs)
  674. def __eq__(self, other: object):
  675. return isinstance(other,
  676. self.__class__) and self.outputs == other.outputs
  677. def __repr__(self) -> str:
  678. """Show the shape of a tensor instead of its values to reduce noise.
  679. """
  680. sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
  681. else self.sampled_token_probs.shape)
  682. sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
  683. self.sampled_token_ids.shape)
  684. return (
  685. f"SamplerOutput(outputs={self.outputs}, "
  686. f"sampled_token_probs={sampled_token_probs_repr}, "
  687. f"sampled_token_ids={sampled_token_ids_repr}, "
  688. f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
  689. @dataclass
  690. class PoolerOutput:
  691. """The output from a pooling operation in the embedding model."""
  692. outputs: List[EmbeddingSequenceGroupOutput]
  693. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  694. def __getitem__(self, idx: int):
  695. return self.outputs[idx]
  696. def __setitem__(self, idx: int, value):
  697. self.outputs[idx] = value
  698. def __len__(self):
  699. return len(self.outputs)
  700. def __eq__(self, other: object):
  701. return isinstance(other,
  702. self.__class__) and self.outputs == other.outputs
  703. def get_all_seq_ids(
  704. seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
  705. """Given a list of SequenceGroupMetadata, create a list of all
  706. sequence ids.
  707. """
  708. return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
  709. class HiddenStates:
  710. """Hidden states corresponding to in-progress sequences.
  711. Used in speculative decoding to pass hidden states from
  712. the target model to the proposer model in the subsequent step.
  713. seq_ids are the sequence ids of each entry of the batch
  714. dimension of the hidden_states tensor"""
  715. def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
  716. hidden_states: torch.Tensor):
  717. assert len(seq_group_metadata_list) == len(hidden_states)
  718. self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
  719. self.hidden_states: torch.Tensor = hidden_states
  720. def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
  721. hidden_states: torch.Tensor) -> None:
  722. """Update hidden states from target model invocation."""
  723. assert len(seq_group_metadata_list) == len(hidden_states)
  724. self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
  725. self.hidden_states = torch.cat([self.hidden_states, hidden_states])
  726. def prune(self,
  727. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  728. """Prune to provided list of sequence ids."""
  729. seq_ids = get_all_seq_ids(seq_group_metadata_list)
  730. if seq_ids != self.seq_ids:
  731. # Batch contents changed - prune removed sequences.
  732. index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
  733. self.hidden_states = self.hidden_states[index]
  734. self.seq_ids = seq_ids
  735. @dataclass
  736. class ExecuteModelRequest:
  737. """The model execution request, containing CPU metadata only. The LLM
  738. engine should create an instance of this class for each request batch."""
  739. # The sequence group metadata list.
  740. seq_group_metadata_list: List[SequenceGroupMetadata]
  741. # Blocks to swap in. List of CPU -> GPU block number.
  742. blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
  743. # Blocks to swap out. List of GPU -> CPU block number.
  744. blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
  745. # Blocks to copy. Source to dest block.
  746. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
  747. # Virtual engine ID for pipeline parallel.
  748. virtual_engine: int = 0
  749. # The number of slots for lookahead decoding.
  750. num_lookahead_slots: int = 0
  751. # The number of requests in the running queue.
  752. running_queue_size: int = 0
  753. # Optional hidden states from prior step.
  754. previous_hidden_states: Optional[HiddenStates] = None
  755. # The number of forward steps to run.
  756. num_steps: int = 1
  757. # Finished request ids since last step.
  758. finished_requests_ids: List[str] = field(default_factory=list)
  759. def clone(
  760. self, seq_group_metadata_list: List[SequenceGroupMetadata]
  761. ) -> "ExecuteModelRequest":
  762. """Clone the request with a new sequence group metadata list."""
  763. return ExecuteModelRequest(
  764. seq_group_metadata_list=seq_group_metadata_list,
  765. blocks_to_swap_in=self.blocks_to_swap_in.copy(),
  766. blocks_to_swap_out=self.blocks_to_swap_out.copy(),
  767. blocks_to_copy=self.blocks_to_copy.copy(),
  768. virtual_engine=self.virtual_engine,
  769. num_lookahead_slots=self.num_lookahead_slots,
  770. running_queue_size=self.running_queue_size,
  771. previous_hidden_states=self.previous_hidden_states,
  772. num_steps=self.num_steps,
  773. finished_requests_ids=self.finished_requests_ids,
  774. )