sequence.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160
  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. from abc import ABC, abstractmethod
  5. from array import array
  6. from collections import defaultdict
  7. from dataclasses import dataclass, field
  8. from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, cast
  9. import numpy
  10. import torch
  11. from aphrodite.common.pooling_params import PoolingParams
  12. from aphrodite.common.sampling_params import SamplingParams
  13. from aphrodite.inputs.parse import is_valid_encoder_decoder_llm_inputs
  14. from aphrodite.lora.request import LoRARequest
  15. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  16. if TYPE_CHECKING:
  17. from aphrodite.inputs import LLMInputs
  18. from aphrodite.multimodal import MultiModalDataDict
  19. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  20. @dataclass
  21. class Logprob:
  22. """Infos for supporting OpenAI compatible logprobs and token ranks.
  23. Attributes:
  24. logprob: The logprob of chosen token
  25. rank: The vocab rank of chosen token (>=1)
  26. decoded_token: The decoded chosen token index
  27. """
  28. logprob: float
  29. rank: Optional[int] = None
  30. decoded_token: Optional[str] = None
  31. # {token_id -> logprob} per each sequence group. None if the corresponding
  32. # sequence group doesn't require prompt logprob.
  33. PromptLogprobs = List[Optional[Dict[int, Logprob]]]
  34. # {token_id -> logprob} for each sequence group.
  35. SampleLogprobs = List[Dict[int, Logprob]]
  36. class SequenceStatus(enum.IntEnum):
  37. """Status of a sequence."""
  38. WAITING = 0
  39. RUNNING = 1
  40. SWAPPED = 2
  41. # Note: anything after SWAPPED (2) will be considered
  42. # as a finished status.
  43. FINISHED_STOPPED = 3
  44. FINISHED_LENGTH_CAPPED = 4
  45. FINISHED_ABORTED = 5
  46. FINISHED_IGNORED = 6
  47. @staticmethod
  48. def is_finished(status: "SequenceStatus") -> bool:
  49. return status > SequenceStatus.SWAPPED
  50. @staticmethod
  51. def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
  52. if status == SequenceStatus.FINISHED_STOPPED:
  53. finish_reason = "stop"
  54. elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
  55. finish_reason = "length"
  56. elif status == SequenceStatus.FINISHED_ABORTED:
  57. finish_reason = "abort"
  58. elif status == SequenceStatus.FINISHED_IGNORED:
  59. # The ignored sequences are the sequences whose prompt lengths
  60. # are longer than the model's length cap. Therefore, the stop
  61. # reason should also be "length" as in OpenAI API.
  62. finish_reason = "length"
  63. else:
  64. finish_reason = None
  65. return finish_reason
  66. class SequenceStage(enum.Enum):
  67. PREFILL = enum.auto()
  68. DECODE = enum.auto()
  69. @dataclass
  70. class RequestMetrics:
  71. """Metrics associated with a request.
  72. Attributes:
  73. arrival_time: The time when the request arrived.
  74. first_scheduled_time: The time when the request was first scheduled.
  75. first_token_time: The time when the first token was generated.
  76. time_in_queue: The time the request spent in the queue.
  77. finished_time: The time when the request was finished.
  78. """
  79. arrival_time: float
  80. last_token_time: float
  81. first_scheduled_time: Optional[float]
  82. first_token_time: Optional[float]
  83. time_in_queue: Optional[float]
  84. finished_time: Optional[float] = None
  85. class SequenceData:
  86. """Data associated with a sequence.
  87. Args:
  88. prompt_token_ids: The token IDs of the prompt.
  89. output_token_ids: The token IDs of the output. Set to an empty list if
  90. None.
  91. Attributes:
  92. prompt_token_ids: The token IDs of the prompt.
  93. output_token_ids: The token IDs of the output.
  94. cumulative_logprob: The cumulative log probability of the output.
  95. """
  96. def __init__(
  97. self,
  98. prompt_token_ids: List[int],
  99. output_token_ids: Optional[List[int]] = None,
  100. ) -> None:
  101. self._prompt_token_ids = array('l', prompt_token_ids)
  102. self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
  103. self._output_token_ids = array(
  104. 'l', output_token_ids if output_token_ids is not None else [])
  105. self.cumulative_logprob = 0.0
  106. # The number of tokens that are computed (that run against the model).
  107. self._num_computed_tokens = 0
  108. self._stage: SequenceStage = SequenceStage.PREFILL
  109. self._update_cached_all_tokens()
  110. def _update_cached_all_tokens(self):
  111. self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
  112. self._output_token_ids)
  113. @property
  114. def prompt_token_ids(self) -> Tuple[int, ...]:
  115. return self._prompt_token_ids_tuple
  116. @prompt_token_ids.setter
  117. def prompt_token_ids(self, new_prompt_token_ids) -> None:
  118. self._prompt_token_ids = array('l', new_prompt_token_ids)
  119. self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
  120. self._update_cached_all_tokens()
  121. @property
  122. def prompt_token_ids_array(self) -> array:
  123. return self._prompt_token_ids
  124. @property
  125. def output_token_ids(self) -> Tuple[int, ...]:
  126. return tuple(self._output_token_ids)
  127. @output_token_ids.setter
  128. def output_token_ids(self, new_output_token_ids) -> None:
  129. self._output_token_ids = array('l', new_output_token_ids)
  130. self._update_cached_all_tokens()
  131. @property
  132. def output_token_ids_array(self) -> array:
  133. return self._output_token_ids
  134. def append_token_id(self, token_id: int, logprob: float) -> None:
  135. self._output_token_ids.append(token_id)
  136. self._cached_all_token_ids.append(token_id)
  137. self.cumulative_logprob += logprob
  138. def get_len(self) -> int:
  139. return len(self._output_token_ids) + len(self._prompt_token_ids)
  140. def get_prompt_len(self) -> int:
  141. return len(self._prompt_token_ids)
  142. def get_output_len(self) -> int:
  143. return len(self._output_token_ids)
  144. def get_token_ids(self) -> List[int]:
  145. return self._cached_all_token_ids
  146. def get_prefix_token_ids(
  147. self, num_tokens: int
  148. ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
  149. """Get prefix tokens, and make the return value hashable"""
  150. prompt_length = self.get_prompt_len()
  151. if num_tokens > prompt_length:
  152. return (self._prompt_token_ids_tuple,
  153. tuple(self._output_token_ids[:num_tokens - prompt_length]))
  154. else:
  155. return (self._prompt_token_ids_tuple[:num_tokens], None)
  156. def get_num_computed_tokens(self) -> int:
  157. """Return the number of prefill tokens that are already computed."""
  158. return self._num_computed_tokens
  159. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  160. """Update number of tokens computed so far."""
  161. self._num_computed_tokens += num_new_computed_tokens
  162. assert self._num_computed_tokens <= self.get_len(), (
  163. self._num_computed_tokens, self.get_len())
  164. # If all tokens are computed, it means it is in decoding phase.
  165. if self.get_num_uncomputed_tokens() == 0:
  166. self._stage = SequenceStage.DECODE
  167. def reset_state_for_recompute(self) -> None:
  168. """Reset the number of computed tokens from this sequence. It is
  169. supposed to be called when a sequence needs to be started from
  170. the beginning again (e.g., sequence is preempted).
  171. """
  172. self._num_computed_tokens = 0
  173. self._stage = SequenceStage.PREFILL
  174. def get_num_uncomputed_tokens(self) -> int:
  175. """Return the number of prefill tokens that are not computed."""
  176. # we use `get_len()` which includes prompt_len + output_len instead
  177. # of prompt_len here. This is because during recompute we need to
  178. # prefill for both prompt and output.
  179. return self.get_len() - self.get_num_computed_tokens()
  180. def get_last_token_id(self) -> int:
  181. if not self._output_token_ids:
  182. return self._prompt_token_ids[-1]
  183. return self._output_token_ids[-1]
  184. def get_prompt_token_ids(self) -> Tuple[int, ...]:
  185. return self.prompt_token_ids
  186. def get_output_token_ids(self) -> Tuple[int, ...]:
  187. return self.output_token_ids
  188. @property
  189. def stage(self) -> SequenceStage:
  190. return self._stage
  191. def __repr__(self) -> str:
  192. return (f"SequenceData("
  193. f"prompt_token_ids={self._prompt_token_ids}, "
  194. f"output_token_ids={self._output_token_ids}, "
  195. f"cumulative_logprob={self.cumulative_logprob})")
  196. class Sequence:
  197. """Stores the data, status, and block information of a sequence.
  198. The sequence is constructed from the LLMInputs instance passed
  199. in through the `inputs` constructor argument.
  200. For encoder/decoder models, LLMInputs encapsulates both a
  201. decoder and encoder prompt, creating an ambiguity about which
  202. prompt to construct the sequence from. The `from_decoder_prompt`
  203. constructor argument signals whether to construct the Sequence
  204. from the LLMInputs decoder prompt, or encoder prompt.
  205. Args:
  206. seq_id: The ID of the sequence.
  207. inputs: The inputs of the sequence.
  208. block_size: The block size of the sequence. Should be the same as the
  209. block size used by the block manager and cache engine.
  210. eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
  211. lora_request: LoRA request.
  212. prompt_adapter_request: Prompt Adapter request.
  213. from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
  214. (True) or encoder prompt (False.) Must be True
  215. for decoder-only model.
  216. """
  217. def __init__(
  218. self,
  219. seq_id: int,
  220. inputs: "LLMInputs",
  221. block_size: int,
  222. eos_token_id: Optional[int] = None,
  223. lora_request: Optional[LoRARequest] = None,
  224. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  225. from_decoder_prompt: bool = True,
  226. ) -> None:
  227. self.seq_id = seq_id
  228. self.inputs = inputs
  229. self.block_size = block_size
  230. self.eos_token_id = eos_token_id
  231. self.lora_request = lora_request
  232. self.prompt_adapter_request = prompt_adapter_request
  233. self.from_decoder_prompt = from_decoder_prompt
  234. self._prompt: Optional[str] = None
  235. self._prompt_token_ids: Optional[List[int]] = None
  236. # For decoder-only models, a Sequence is constructed
  237. # from an LLMInputs instance (the `inputs` arg.)
  238. #
  239. # For encoder/decoder models the same `inputs`
  240. # instance could be utilized to construct either an
  241. # encoder sequence or a decoder sequence, because
  242. # `LLMInputs` has both decoder- and encoder-oriented
  243. # member variables (i.e. it encapsulates both an encoder
  244. # and a decoder prompt.) The decision of which type of sequence
  245. # to generate is determined by the `from_decoder_prompt` argument.
  246. #
  247. # When constructing a encoder sequence
  248. # (`from_decoder_prompt` False) it matters that
  249. # the `LLMInputs` instance stored in `inputs` is valid
  250. # in the sense that its encoder-related member variables are
  251. # populated; below, an exception is raised if this is
  252. # not the case.
  253. #
  254. # When constructing a decoder sequence (`from_decoder_prompt` True)
  255. # it does not matter whether `inputs` has its encoder-related
  256. # member variables populated.
  257. if not (from_decoder_prompt
  258. or is_valid_encoder_decoder_llm_inputs(inputs)):
  259. raise ValueError("Cannot extract encoder input prompt from "
  260. f"invalid input {inputs}; did you forget the "
  261. "encoder input prompt fields?")
  262. self.data = SequenceData(self.prompt_token_ids)
  263. self.output_logprobs: SampleLogprobs = []
  264. self.output_text = ""
  265. self.status = SequenceStatus.WAITING
  266. self.stop_reason: Union[int, str, None] = None
  267. # Used for incremental detokenization
  268. self.prefix_offset = 0
  269. self.read_offset = 0
  270. # Input + output tokens
  271. self.tokens: Optional[List[str]] = None
  272. @property
  273. def n_blocks(self) -> int:
  274. return (self.get_len() + self.block_size - 1) // self.block_size
  275. @property
  276. def prompt(self) -> Optional[str]:
  277. if self._prompt is not None:
  278. # Reuse precomputed prompt string
  279. return self._prompt
  280. # Select decoder or encoder input prompt str,
  281. # as appropriate
  282. prompt_key: str = ("prompt"
  283. if self.from_decoder_prompt else "encoder_prompt")
  284. # Cache prompt
  285. self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
  286. return self._prompt
  287. @property
  288. def prompt_token_ids(self) -> List[int]:
  289. if self._prompt_token_ids is not None:
  290. # Reuse precomputed prompt token ids
  291. return self._prompt_token_ids
  292. # Select decoder or encoder input prompt
  293. # token ids, as appropriate
  294. prompt_token_ids_key: str = ("prompt_token_ids"
  295. if self.from_decoder_prompt else
  296. "encoder_prompt_token_ids")
  297. # Cache computed prompt token ids
  298. self._prompt_token_ids = cast(List[int],
  299. self.inputs.get(prompt_token_ids_key))
  300. return self._prompt_token_ids
  301. @property
  302. def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
  303. return self.inputs.get("multi_modal_data")
  304. @property
  305. def lora_int_id(self) -> int:
  306. return self.lora_request.lora_int_id if self.lora_request else 0
  307. @property
  308. def prompt_adapter_id(self) -> int:
  309. return self.prompt_adapter_request.prompt_adapter_id \
  310. if self.prompt_adapter_request else 0
  311. def get_output_text_to_return(self, buffer_length: int):
  312. # We return the full output text if the sequence is finished.
  313. truncate = buffer_length and not self.is_finished()
  314. return self.output_text[:-buffer_length] if truncate else (
  315. self.output_text)
  316. def hash_of_block(self, logical_idx: int) -> int:
  317. # TODO This can produce incorrect hash when block size > prompt size
  318. # Compute the number of tokens in the sequence
  319. # TODO: The current hashing function is O(L^2). We should optimize
  320. # this in the future.
  321. num_tokens = self.num_hashed_tokens_of_block(logical_idx)
  322. hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
  323. return hash((hashed_tokens, self.lora_int_id))
  324. def num_hashed_tokens_of_block(self, logical_idx: int):
  325. return logical_idx * self.block_size + self.block_size
  326. def reset_state_for_recompute(self):
  327. """Reset the sequence states for recomputation."""
  328. self.data.reset_state_for_recompute()
  329. def append_token_id(
  330. self,
  331. token_id: int,
  332. logprobs: Dict[int, Logprob],
  333. ) -> None:
  334. assert token_id in logprobs
  335. self.output_logprobs.append(logprobs)
  336. self.data.append_token_id(token_id, logprobs[token_id].logprob)
  337. def get_len(self) -> int:
  338. return self.data.get_len()
  339. def get_prompt_len(self) -> int:
  340. return self.data.get_prompt_len()
  341. def get_output_len(self) -> int:
  342. return self.data.get_output_len()
  343. def get_token_ids(self) -> List[int]:
  344. return self.data.get_token_ids()
  345. def get_prompt_token_ids(self) -> Tuple[int, ...]:
  346. return self.data.get_prompt_token_ids()
  347. def get_last_token_id(self) -> int:
  348. return self.data.get_last_token_id()
  349. def get_output_token_ids(self) -> Tuple[int, ...]:
  350. return self.data.get_output_token_ids()
  351. def get_cumulative_logprob(self) -> float:
  352. return self.data.cumulative_logprob
  353. def get_beam_search_score(self,
  354. length_penalty: float = 1.0,
  355. seq_len: Optional[int] = None,
  356. eos_token_id: Optional[int] = None) -> float:
  357. """Calculate the beam search score with length penalty.
  358. Adapted from
  359. https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
  360. """
  361. if seq_len is None:
  362. seq_len = self.get_len()
  363. # NOTE: HF implementation does not count the EOS token
  364. # towards the length, we align with that here for testing.
  365. if (eos_token_id is not None
  366. and self.get_last_token_id() == eos_token_id):
  367. seq_len -= 1
  368. return self.get_cumulative_logprob() / (seq_len**length_penalty)
  369. def is_finished(self) -> bool:
  370. return SequenceStatus.is_finished(self.status)
  371. def fork(self, new_seq_id: int) -> "Sequence":
  372. new_seq = copy.deepcopy(self)
  373. new_seq.seq_id = new_seq_id
  374. return new_seq
  375. def get_num_new_tokens(self) -> int:
  376. """Get the number of new tokens to be computed.
  377. Returns:
  378. The new number of tokens to be computed. I.e., 1 for decode, or
  379. the remaining prompt size for prefill.
  380. """
  381. if self.data.stage == SequenceStage.DECODE:
  382. return 1
  383. return self.data.get_num_uncomputed_tokens()
  384. def is_prefill(self) -> bool:
  385. return self.data.stage == SequenceStage.PREFILL
  386. def __repr__(self) -> str:
  387. return (f"Sequence(seq_id={self.seq_id}, "
  388. f"status={self.status.name}, "
  389. f"num_blocks={self.n_blocks}, ")
  390. @dataclass
  391. class SequenceGroupState:
  392. """Mutable state tied to a specific sequence group"""
  393. # for multi-step decoding
  394. num_steps: int = 1
  395. current_step: int = 0
  396. @property
  397. def remaining_steps(self) -> int:
  398. return self.num_steps - self.current_step
  399. class SequenceGroup:
  400. """A group of sequences that are generated from the same prompt.
  401. Args:
  402. request_id: The ID of the request.
  403. seqs: The list of sequences.
  404. sampling_params: The sampling parameters used to generate the outputs.
  405. arrival_time: The arrival time of the request.
  406. lora_request: LoRA request.
  407. embeddings: The embeddings vectors of the prompt of the sequence group
  408. for an embedding model.
  409. pooling_params: The pooling parameters used to generate the pooling
  410. for an embedding model.
  411. encoder_seq: Optional, the single encoder sequence. Should be None
  412. unless you are working with an encoder/decoder model.
  413. prompt_adapter_request: Prompt adapter request.
  414. """
  415. def __init__(
  416. self,
  417. request_id: str,
  418. seqs: List[Sequence],
  419. arrival_time: float,
  420. sampling_params: Optional[SamplingParams] = None,
  421. lora_request: Optional[LoRARequest] = None,
  422. embeddings: Optional[List[float]] = None,
  423. pooling_params: Optional[PoolingParams] = None,
  424. encoder_seq: Optional[Sequence] = None,
  425. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  426. ) -> None:
  427. self.request_id = request_id
  428. self.seqs = seqs
  429. self.is_single_seq = len(seqs) == 1
  430. self.seqs_dict = {seq.seq_id: seq for seq in seqs}
  431. self.sampling_params = sampling_params
  432. self.metrics = RequestMetrics(arrival_time=arrival_time,
  433. last_token_time=arrival_time,
  434. first_scheduled_time=None,
  435. first_token_time=None,
  436. time_in_queue=None)
  437. self.lora_request = lora_request
  438. self.prompt_logprobs: Optional[PromptLogprobs] = None
  439. self.state = SequenceGroupState()
  440. self.embeddings = embeddings
  441. self.pooling_params = pooling_params
  442. self.prompt_adapter_request = prompt_adapter_request
  443. self.encoder_seq = encoder_seq
  444. @property
  445. def prompt(self) -> Optional[str]:
  446. # All sequences in the group should have the same prompt.
  447. # We use the prompt of an arbitrary sequence.
  448. return self.seqs[0].prompt
  449. @property
  450. def prompt_token_ids(self) -> List[int]:
  451. # All sequences in the group should have the same prompt.
  452. # We use the prompt of an arbitrary sequence.
  453. return self.seqs[0].prompt_token_ids
  454. @property
  455. def encoder_prompt(self) -> Optional[str]:
  456. # There are either 0 or 1 encoder sequences
  457. # If one is present, its prompt is distinct
  458. # from the decoder's.
  459. return (self.encoder_seq.prompt
  460. if self.encoder_seq is not None else None)
  461. @property
  462. def encoder_prompt_token_ids(self) -> Optional[List[int]]:
  463. # There are either 0 or 1 encoder sequences
  464. # If one is present, its prompt token ids are
  465. # distinct from the decoder's.
  466. return (self.encoder_seq.prompt_token_ids
  467. if self.encoder_seq is not None else None)
  468. @property
  469. def multi_modal_data(self) -> "MultiModalDataDict":
  470. # All sequences in the group should have the same multi-modal data.
  471. # We use the multi-modal data of an arbitrary sequence.
  472. return self.seqs[0].multi_modal_data
  473. @property
  474. def lora_int_id(self) -> int:
  475. return self.lora_request.lora_int_id if self.lora_request else 0
  476. @property
  477. def prompt_adapter_id(self) -> int:
  478. return self.prompt_adapter_request.prompt_adapter_id \
  479. if self.prompt_adapter_request else 0
  480. @property
  481. def prompt_adapter_num_virtual_tokens(self) -> int:
  482. return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
  483. if self.prompt_adapter_request else 0
  484. def init_multi_step(self, num_scheduler_steps: int) -> None:
  485. self.state.num_steps = num_scheduler_steps
  486. self.state.current_step = 0
  487. def get_last_latency(self, now: float) -> Optional[float]:
  488. """Sets the last token time for Request level timings."""
  489. # If still in prefill phase, raise Error.
  490. if self.is_prefill():
  491. raise ValueError(
  492. "seq_group.get_last_latency() should not be called "
  493. "if the seq_group is in prefill phase.")
  494. # Otherwise return token latency.
  495. latency = now - self.metrics.last_token_time
  496. self.metrics.last_token_time = now
  497. return latency
  498. def maybe_set_first_token_time(self, time: float) -> None:
  499. """Sets the first token time for Request level timings."""
  500. # NOTE: in a case where a sequence_group is swapped and
  501. # recomputed, the time between iterations is counted
  502. # in TPOT, rather than recalculating TTFT (since from the )
  503. # POV of the user, there is simply a long generation delay.
  504. if (self.metrics.first_token_time is None
  505. and self.seqs[0].get_output_len() == 1):
  506. self.metrics.first_token_time = time
  507. def maybe_set_first_scheduled_time(self, time: float) -> None:
  508. """Sets the first scheduled time and time in queue for Request
  509. level timings."""
  510. if self.metrics.first_scheduled_time is None:
  511. self.metrics.first_scheduled_time = time
  512. self.metrics.time_in_queue = time - self.metrics.arrival_time
  513. def set_finished_time(self, time: Optional[float]) -> None:
  514. """Sets the finished time for Request level timings."""
  515. self.metrics.finished_time = time
  516. def get_max_num_running_seqs(self) -> int:
  517. """The maximum number of sequences running in parallel in the remaining
  518. lifetime of the request."""
  519. if self.sampling_params and self.sampling_params.use_beam_search:
  520. # For beam search, maximally there will always be `best_of` beam
  521. # candidates running in the future.
  522. return self.sampling_params.best_of
  523. else:
  524. if (self.sampling_params
  525. and self.sampling_params.best_of > self.num_seqs()):
  526. # At prompt stage, the sequence group is not yet filled up
  527. # and only have one sequence running. However, in the
  528. # generation stage, we will have `best_of` sequences running.
  529. return self.sampling_params.best_of
  530. # At sampling stages, return the number of actual sequences
  531. # that are not finished yet.
  532. return self.num_unfinished_seqs()
  533. def get_seqs(
  534. self,
  535. status: Optional[SequenceStatus] = None,
  536. ) -> List[Sequence]:
  537. if status is None:
  538. return self.seqs
  539. if self.is_single_seq:
  540. return self.seqs if self.seqs[0].status == status else []
  541. return [seq for seq in self.seqs if seq.status == status]
  542. def is_encoder_decoder(self) -> bool:
  543. return self.encoder_seq is not None
  544. def get_encoder_seq(self) -> Optional[Sequence]:
  545. return self.encoder_seq
  546. def get_unfinished_seqs(self) -> List[Sequence]:
  547. if self.is_single_seq:
  548. return self.seqs if not self.seqs[0].is_finished() else []
  549. return [seq for seq in self.seqs if not seq.is_finished()]
  550. def get_finished_seqs(self) -> List[Sequence]:
  551. if self.is_single_seq:
  552. return self.seqs if self.seqs[0].is_finished() else []
  553. return [seq for seq in self.seqs if seq.is_finished()]
  554. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  555. """Update number of tokens computed so far."""
  556. for seq in self.seqs:
  557. if not seq.is_finished():
  558. seq.data.update_num_computed_tokens(num_new_computed_tokens)
  559. def get_num_uncomputed_tokens(self) -> int:
  560. num_uncomputed_tokens = 0
  561. for seq in self.seqs:
  562. if not seq.is_finished():
  563. num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
  564. return num_uncomputed_tokens
  565. def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
  566. # Optimization. We don't need to call get_seqs if we don't need to
  567. # filter by states.
  568. if status is None:
  569. return len(self.seqs)
  570. if self.is_single_seq:
  571. return 1 if self.seqs[0].status == status else 0
  572. return len(self.get_seqs(status))
  573. def num_unfinished_seqs(self) -> int:
  574. if self.is_single_seq:
  575. return 1 if not self.seqs[0].is_finished() else 0
  576. return len(self.get_unfinished_seqs())
  577. def num_finished_seqs(self) -> int:
  578. if self.is_single_seq:
  579. return 1 if self.seqs[0].is_finished() else 0
  580. return len(self.get_finished_seqs())
  581. def find(self, seq_id: int) -> Sequence:
  582. if seq_id not in self.seqs_dict:
  583. raise ValueError(f"Sequence {seq_id} not found.")
  584. return self.seqs_dict[seq_id]
  585. def add(self, seq: Sequence) -> None:
  586. if seq.seq_id in self.seqs_dict:
  587. raise ValueError(f"Sequence {seq.seq_id} already exists.")
  588. self.seqs_dict[seq.seq_id] = seq
  589. self.seqs.append(seq)
  590. self.is_single_seq = len(self.seqs) == 1
  591. def remove(self, seq_id: int) -> None:
  592. seq = self.seqs_dict.pop(seq_id, None)
  593. if seq is None:
  594. raise ValueError(f"Sequence {seq_id} not found.")
  595. self.seqs.remove(seq)
  596. self.is_single_seq = len(self.seqs) == 1
  597. def is_finished(self) -> bool:
  598. return all(seq.is_finished() for seq in self.seqs)
  599. def is_prefill(self) -> bool:
  600. # Every sequence should be in the same stage.
  601. return self.seqs[0].is_prefill()
  602. def __repr__(self) -> str:
  603. return (f"SequenceGroup(request_id={self.request_id}, "
  604. f"sampling_params={self.sampling_params}, "
  605. f"num_seqs={len(self.seqs)})")
  606. class SequenceGroupMetadata:
  607. """Metadata for a sequence group. Used to create `AttentionMetadata`.
  608. Args:
  609. request_id: The ID of the request.
  610. is_prompt: Whether the request is at prompt stage.
  611. seq_data: The sequence data. (Seq id -> sequence data)
  612. sampling_params: The sampling parameters used to generate the outputs.
  613. block_tables: The block tables. (Seq id -> list of physical block
  614. numbers)
  615. do_sample: True if sampling is required. Sampling is not required when
  616. e.g., prefill is chunked, and the current iteration only computes
  617. query tokens for prefill, we don't need sampling.
  618. token_chunk_size: The number of tokens to be processed (per sequence).
  619. None if chunking is not required.
  620. lora_request: LoRA request.
  621. state: Internal state tied to this sequence group.
  622. computed_block_nums: The block numbers that are already computed,
  623. used in prefix caching.
  624. multi_modal_data: Multi modal data.
  625. encoder_seq_data: Optional sequence data for encoder prompt
  626. (SequenceGroup.encoder_seq). Should be None
  627. unless you are working with an encoder/decoder
  628. model.
  629. cross_block_table: Optional cross-attention block table associated
  630. with the encoder prompt
  631. (SequenceGroup.encoder_seq). Should be None
  632. unless you are working with an encoder/decoder
  633. model.
  634. prompt_adapter_request: Prompt Adapter request.
  635. """
  636. def __init__(
  637. self,
  638. request_id: str,
  639. is_prompt: bool,
  640. seq_data: Dict[int, SequenceData],
  641. sampling_params: SamplingParams,
  642. block_tables: Dict[int, List[int]],
  643. do_sample: bool = True,
  644. pooling_params: Optional[PoolingParams] = None,
  645. token_chunk_size: Optional[int] = None,
  646. lora_request: Optional[LoRARequest] = None,
  647. computed_block_nums: Optional[List[int]] = None,
  648. state: Optional[SequenceGroupState] = None,
  649. multi_modal_data: Optional["MultiModalDataDict"] = None,
  650. encoder_seq_data: Optional[SequenceData] = None,
  651. cross_block_table: Optional[List[int]] = None,
  652. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  653. ) -> None:
  654. self.request_id = request_id
  655. self.is_prompt = is_prompt
  656. self.seq_data = seq_data
  657. self.sampling_params = sampling_params
  658. self.block_tables = block_tables
  659. self.pooling_params = pooling_params
  660. self.lora_request = lora_request
  661. self.prompt_adapter_request = prompt_adapter_request
  662. self.computed_block_nums = computed_block_nums
  663. self.multi_modal_data = multi_modal_data
  664. self.state = SequenceGroupState() if state is None else state
  665. self.encoder_seq_data = encoder_seq_data
  666. self.cross_block_table = cross_block_table
  667. self._token_chunk_size = token_chunk_size
  668. self.do_sample = do_sample
  669. # The number of speculative tokens adopted in this request.
  670. # None means specuative decoding is not used.
  671. # Zero means speculative decoding is disabled for some reasons.
  672. # TODO: We should maintain this states out of the sequence group.
  673. self.num_speculative_tokens = None
  674. if seq_data is not None and self._token_chunk_size is None:
  675. if is_prompt:
  676. self._token_chunk_size = next(iter(
  677. seq_data.values())).get_len()
  678. else:
  679. self._token_chunk_size = 1
  680. @property
  681. def lora_int_id(self) -> int:
  682. return self.lora_request.lora_int_id if self.lora_request else 0
  683. @property
  684. def prompt_adapter_id(self) -> int:
  685. return self.prompt_adapter_request.prompt_adapter_id \
  686. if self.prompt_adapter_request else 0
  687. @property
  688. def prompt_adapter_num_virtual_tokens(self) -> int:
  689. return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
  690. if self.prompt_adapter_request else 0
  691. @property
  692. def token_chunk_size(self) -> int:
  693. """Return the number of tokens to be processed (chunk size)."""
  694. assert self._token_chunk_size is not None
  695. return self._token_chunk_size
  696. def finish_step(self) -> None:
  697. assert self.state.current_step < self.state.num_steps
  698. self.state.current_step += 1
  699. class SequenceOutput:
  700. """The model output associated with a sequence.
  701. Args:
  702. parent_seq_id: The ID of the parent sequence (for forking in beam
  703. search).
  704. output_token: The output token ID.
  705. logprobs: The logprobs of the output token.
  706. (Token id -> logP(x_i+1 | x_0, ..., x_i))
  707. """
  708. def __init__(
  709. self,
  710. parent_seq_id: int,
  711. output_token: int,
  712. logprobs: Dict[int, Logprob],
  713. ) -> None:
  714. self.parent_seq_id = parent_seq_id
  715. self.output_token = output_token
  716. self.logprobs = logprobs
  717. def __repr__(self) -> str:
  718. return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
  719. f"output_token={self.output_token}, "
  720. f"logprobs={self.logprobs})")
  721. def __eq__(self, other: object) -> bool:
  722. if not isinstance(other, SequenceOutput):
  723. raise NotImplementedError()
  724. equal = (self.parent_seq_id == other.parent_seq_id
  725. and self.output_token == other.output_token)
  726. log_probs_equal = other.logprobs == self.logprobs
  727. return equal and log_probs_equal
  728. class SequenceGroupOutput(ABC):
  729. """The base class for model outputs associated with a sequence group."""
  730. @abstractmethod
  731. def __repr__(self) -> str:
  732. pass
  733. @abstractmethod
  734. def __eq__(self, other: object) -> bool:
  735. pass
  736. class CompletionSequenceGroupOutput(SequenceGroupOutput):
  737. """The model output associated with a completion sequence group."""
  738. def __init__(
  739. self,
  740. samples: List[SequenceOutput],
  741. prompt_logprobs: Optional[PromptLogprobs],
  742. ) -> None:
  743. self.samples = samples
  744. # Prompt logprob for each prompt query token.
  745. self.prompt_logprobs = prompt_logprobs
  746. def __repr__(self) -> str:
  747. return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
  748. f"prompt_logprobs={self.prompt_logprobs})")
  749. def __eq__(self, other: object) -> bool:
  750. if not isinstance(other, CompletionSequenceGroupOutput):
  751. raise NotImplementedError()
  752. return (self.samples == other.samples
  753. and self.prompt_logprobs == other.prompt_logprobs)
  754. class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
  755. """The model output associated with an embedding sequence group."""
  756. def __init__(
  757. self,
  758. embeddings: List[float],
  759. ) -> None:
  760. self.embeddings = embeddings
  761. def __repr__(self) -> str:
  762. return (f"EmbeddingSequenceGroupOutput("
  763. f"embeddings_shape={len(self.embeddings)})")
  764. def __eq__(self, other: object) -> bool:
  765. if not isinstance(other, EmbeddingSequenceGroupOutput):
  766. raise NotImplementedError()
  767. return self.embeddings == other.embeddings
  768. @dataclass
  769. class IntermediateTensors:
  770. """For all pipeline stages except the last, we need to return the hidden
  771. states and residuals to be sent to the next stage. This data structure
  772. contains the hidden states and residuals for a request.
  773. """
  774. tensors: Dict[str, torch.Tensor]
  775. def __getitem__(self, key: Union[str, slice]):
  776. if isinstance(key, str):
  777. return self.tensors[key]
  778. elif isinstance(key, slice):
  779. return self.__class__({k: v[key] for k, v in self.tensors.items()})
  780. def __setitem__(self, key: str, value):
  781. self.tensors[key] = value
  782. def __len__(self):
  783. return len(self.tensors)
  784. def __eq__(self, other: object):
  785. return isinstance(other, self.__class__) and self
  786. def __repr__(self) -> str:
  787. return f"IntermediateTensors(tensors={self.tensors})"
  788. @dataclass
  789. class SamplerOutput:
  790. """For each sequence group, we generate a list of SequenceOutput object,
  791. each of which contains one possible candidate for the next token.
  792. This data structure implements methods, so it can be used like a list, but
  793. also has optional fields for device tensors.
  794. """
  795. outputs: List[CompletionSequenceGroupOutput]
  796. # On-device tensor containing probabilities of each token.
  797. sampled_token_probs: Optional[torch.Tensor] = None
  798. # On-device tensor containing the logprobs of each token.
  799. logprobs: Optional["torch.Tensor"] = None
  800. # On-device tensor containing the sampled token ids.
  801. sampled_token_ids: Optional[torch.Tensor] = None
  802. sampled_token_ids_numpy: Optional[numpy.ndarray] = None
  803. # Spec decode metrics populated by workers.
  804. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  805. # Optional last hidden states from the model.
  806. hidden_states: Optional[torch.Tensor] = None
  807. def __getitem__(self, idx: int):
  808. return self.outputs[idx]
  809. def __setitem__(self, idx: int, value):
  810. self.outputs[idx] = value
  811. def __len__(self):
  812. return len(self.outputs)
  813. def __eq__(self, other: object):
  814. return isinstance(other,
  815. self.__class__) and self.outputs == other.outputs
  816. def __repr__(self) -> str:
  817. """Show the shape of a tensor instead of its values to reduce noise.
  818. """
  819. sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
  820. else self.sampled_token_probs.shape)
  821. sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
  822. self.sampled_token_ids.shape)
  823. return (
  824. f"SamplerOutput(outputs={self.outputs}, "
  825. f"sampled_token_probs={sampled_token_probs_repr}, "
  826. f"sampled_token_ids={sampled_token_ids_repr}, "
  827. f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
  828. @dataclass
  829. class PoolerOutput:
  830. """The output from a pooling operation in the embedding model."""
  831. outputs: List[EmbeddingSequenceGroupOutput]
  832. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  833. def __getitem__(self, idx: int):
  834. return self.outputs[idx]
  835. def __setitem__(self, idx: int, value):
  836. self.outputs[idx] = value
  837. def __len__(self):
  838. return len(self.outputs)
  839. def __eq__(self, other: object):
  840. return isinstance(other,
  841. self.__class__) and self.outputs == other.outputs
  842. def get_all_seq_ids(
  843. seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
  844. """Given a list of SequenceGroupMetadata, create a list of all
  845. sequence ids.
  846. """
  847. return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
  848. def get_all_seq_ids_and_request_ids(
  849. seq_group_metadata_list: List[SequenceGroupMetadata]
  850. ) -> Tuple[List[int], Dict[str, Set[int]]]:
  851. """Given a list of SequenceGroupMetadata, create a list of all
  852. sequence ids.
  853. """
  854. seq_ids: List[int] = []
  855. request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
  856. for sg in seq_group_metadata_list:
  857. for seq_id in sg.seq_data:
  858. seq_ids.append(seq_id)
  859. request_id_seq_ids_mapping[sg.request_id].add(seq_id)
  860. return seq_ids, request_id_seq_ids_mapping
  861. class HiddenStates:
  862. """Hidden states corresponding to in-progress sequences.
  863. Used in speculative decoding to pass hidden states from
  864. the target model to the proposer model in the subsequent step.
  865. seq_ids are the sequence ids of each entry of the batch
  866. dimension of the hidden_states tensor"""
  867. def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
  868. hidden_states: torch.Tensor):
  869. assert len(seq_group_metadata_list) == len(hidden_states)
  870. self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
  871. self.hidden_states: torch.Tensor = hidden_states
  872. def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
  873. hidden_states: torch.Tensor) -> None:
  874. """Update hidden states from target model invocation."""
  875. assert len(seq_group_metadata_list) == len(hidden_states)
  876. self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
  877. self.hidden_states = torch.cat([self.hidden_states, hidden_states])
  878. def prune(self,
  879. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  880. """Prune to provided list of sequence ids."""
  881. seq_ids = get_all_seq_ids(seq_group_metadata_list)
  882. if seq_ids != self.seq_ids:
  883. # Batch contents changed - prune removed sequences.
  884. index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
  885. self.hidden_states = self.hidden_states[index]
  886. self.seq_ids = seq_ids
  887. @dataclass
  888. class ExecuteModelRequest:
  889. """The model execution request, containing CPU metadata only. The LLM
  890. engine should create an instance of this class for each request batch."""
  891. # The sequence group metadata list.
  892. seq_group_metadata_list: List[SequenceGroupMetadata]
  893. # Blocks to swap in. List of CPU -> GPU block number.
  894. blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
  895. # Blocks to swap out. List of GPU -> CPU block number.
  896. blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
  897. # Blocks to copy. Source to dest block.
  898. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
  899. # Virtual engine ID for pipeline parallel.
  900. virtual_engine: int = 0
  901. # The number of slots for lookahead decoding.
  902. num_lookahead_slots: int = 0
  903. # The number of requests in the running queue.
  904. running_queue_size: int = 0
  905. # Optional hidden states from prior step.
  906. previous_hidden_states: Optional[HiddenStates] = None
  907. # The number of forward steps to run.
  908. num_steps: int = 1
  909. # Finished request ids since last step.
  910. finished_requests_ids: List[str] = field(default_factory=list)
  911. # The last sampled token ids for multi step decoding.
  912. last_sampled_token_ids: Optional[torch.Tensor] = None
  913. @property
  914. def is_first_multi_step(self) -> bool:
  915. # TODO: make this be able to handle batches with variable number of
  916. # steps
  917. assert len(self.seq_group_metadata_list) > 0
  918. first_seq_group = self.seq_group_metadata_list[0]
  919. return first_seq_group.state.current_step == 0
  920. @property
  921. def is_last_step(self) -> bool:
  922. # TODO: make this be able to handle batches with variable number of
  923. # steps
  924. assert len(self.seq_group_metadata_list) > 0
  925. first_seq_group = self.seq_group_metadata_list[0]
  926. num_steps = first_seq_group.state.num_steps
  927. current_step = first_seq_group.state.current_step
  928. return num_steps - current_step == 1
  929. @property
  930. def current_step(self) -> int:
  931. # TODO: make this be able to handle batches with variable number of
  932. # steps
  933. assert len(self.seq_group_metadata_list) > 0
  934. return self.seq_group_metadata_list[0].state.current_step
  935. def clone(
  936. self, seq_group_metadata_list: List[SequenceGroupMetadata]
  937. ) -> "ExecuteModelRequest":
  938. """Clone the request with a new sequence group metadata list."""
  939. return ExecuteModelRequest(
  940. seq_group_metadata_list=seq_group_metadata_list,
  941. blocks_to_swap_in=self.blocks_to_swap_in.copy(),
  942. blocks_to_swap_out=self.blocks_to_swap_out.copy(),
  943. blocks_to_copy=self.blocks_to_copy.copy(),
  944. virtual_engine=self.virtual_engine,
  945. num_lookahead_slots=self.num_lookahead_slots,
  946. running_queue_size=self.running_queue_size,
  947. previous_hidden_states=self.previous_hidden_states,
  948. num_steps=self.num_steps,
  949. finished_requests_ids=self.finished_requests_ids,
  950. last_sampled_token_ids=self.last_sampled_token_ids.clone()
  951. if self.last_sampled_token_ids is not None else None)