sequence.py 37 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006
  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. import math
  5. from abc import ABC, abstractmethod
  6. from collections import defaultdict
  7. from dataclasses import dataclass, field
  8. from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
  9. import torch
  10. from aphrodite.common.pooling_params import PoolingParams
  11. from aphrodite.common.sampling_params import SamplingParams
  12. from aphrodite.lora.request import LoRARequest
  13. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  14. if TYPE_CHECKING:
  15. from aphrodite.inputs import LLMInputs
  16. from aphrodite.multimodal import MultiModalDataDict
  17. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  18. @dataclass
  19. class Logprob:
  20. """Infos for supporting OpenAI compatible logprobs and token ranks.
  21. Attributes:
  22. logprob: The logprob of chosen token
  23. rank: The vocab rank of chosen token (>=1)
  24. decoded_token: The decoded chosen token index
  25. """
  26. logprob: float
  27. rank: Optional[int] = None
  28. decoded_token: Optional[str] = None
  29. # {token_id -> logprob} per each sequence group. None if the corresponding
  30. # sequence group doesn't require prompt logprob.
  31. PromptLogprobs = List[Optional[Dict[int, Logprob]]]
  32. # {token_id -> logprob} for each sequence group.
  33. SampleLogprobs = List[Dict[int, Logprob]]
  34. class SequenceStatus(enum.IntEnum):
  35. """Status of a sequence."""
  36. WAITING = 0
  37. RUNNING = 1
  38. SWAPPED = 2
  39. # Note: anything after SWAPPED (2) will be considered
  40. # as a finished status.
  41. FINISHED_STOPPED = 3
  42. FINISHED_LENGTH_CAPPED = 4
  43. FINISHED_ABORTED = 5
  44. FINISHED_IGNORED = 6
  45. @staticmethod
  46. def is_finished(status: "SequenceStatus") -> bool:
  47. return status > SequenceStatus.SWAPPED
  48. @staticmethod
  49. def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
  50. if status == SequenceStatus.FINISHED_STOPPED:
  51. finish_reason = "stop"
  52. elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
  53. finish_reason = "length"
  54. elif status == SequenceStatus.FINISHED_ABORTED:
  55. finish_reason = "abort"
  56. elif status == SequenceStatus.FINISHED_IGNORED:
  57. # The ignored sequences are the sequences whose prompt lengths
  58. # are longer than the model's length cap. Therefore, the stop
  59. # reason should also be "length" as in OpenAI API.
  60. finish_reason = "length"
  61. else:
  62. finish_reason = None
  63. return finish_reason
  64. class SequenceStage(enum.Enum):
  65. PREFILL = enum.auto()
  66. DECODE = enum.auto()
  67. @dataclass
  68. class RequestMetrics:
  69. """Metrics associated with a request.
  70. Attributes:
  71. arrival_time: The time when the request arrived.
  72. first_scheduled_time: The time when the request was first scheduled.
  73. first_token_time: The time when the first token was generated.
  74. time_in_queue: The time the request spent in the queue.
  75. finished_time: The time when the request was finished.
  76. """
  77. arrival_time: float
  78. last_token_time: float
  79. first_scheduled_time: Optional[float]
  80. first_token_time: Optional[float]
  81. time_in_queue: Optional[float]
  82. finished_time: Optional[float] = None
  83. class SequenceData:
  84. """Data associated with a sequence.
  85. Args:
  86. prompt_token_ids: The token IDs of the prompt.
  87. output_token_ids: The token IDs of the output. Set to an empty list if
  88. None.
  89. Attributes:
  90. prompt_token_ids: The token IDs of the prompt.
  91. output_token_ids: The token IDs of the output.
  92. cumulative_logprob: The cumulative log probability of the output.
  93. """
  94. def __init__(
  95. self,
  96. prompt_token_ids: List[int],
  97. output_token_ids: Optional[List[int]] = None,
  98. ) -> None:
  99. self._prompt_token_ids: List[int] = list(prompt_token_ids)
  100. self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
  101. self._output_token_ids: List[int] = (
  102. list(output_token_ids) if output_token_ids is not None else [])
  103. self.cumulative_logprob = 0.0
  104. # The number of tokens that are computed (that run against the model).
  105. self._num_computed_tokens = 0
  106. self._stage: SequenceStage = SequenceStage.PREFILL
  107. self._update_cached_all_tokens()
  108. def _update_cached_all_tokens(self):
  109. self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
  110. self._output_token_ids)
  111. @property
  112. def prompt_token_ids(self) -> Tuple[int, ...]:
  113. return self._prompt_token_ids_tuple
  114. @prompt_token_ids.setter
  115. def prompt_token_ids(self, new_prompt_token_ids) -> None:
  116. self._prompt_token_ids = list(new_prompt_token_ids)
  117. self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
  118. self._update_cached_all_tokens()
  119. @property
  120. def output_token_ids(self) -> Tuple[int, ...]:
  121. return tuple(self._output_token_ids)
  122. @output_token_ids.setter
  123. def output_token_ids(self, new_output_token_ids) -> None:
  124. self._output_token_ids = list(new_output_token_ids)
  125. self._update_cached_all_tokens()
  126. def append_token_id(self, token_id: int, logprob: float) -> None:
  127. self._output_token_ids.append(token_id)
  128. self._cached_all_token_ids.append(token_id)
  129. self.cumulative_logprob += logprob
  130. def get_len(self) -> int:
  131. return len(self._output_token_ids) + len(self._prompt_token_ids)
  132. def get_prompt_len(self) -> int:
  133. return len(self._prompt_token_ids)
  134. def get_output_len(self) -> int:
  135. return len(self._output_token_ids)
  136. def get_token_ids(self) -> List[int]:
  137. return self._cached_all_token_ids
  138. def get_prefix_token_ids(
  139. self, num_tokens: int
  140. ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
  141. """Get prefix tokens, and make the return value hashable"""
  142. prompt_length = self.get_prompt_len()
  143. if num_tokens > prompt_length:
  144. return (self._prompt_token_ids_tuple,
  145. tuple(self._output_token_ids[:num_tokens - prompt_length]))
  146. else:
  147. return (self._prompt_token_ids_tuple[:num_tokens], None)
  148. def get_num_computed_tokens(self) -> int:
  149. """Return the number of prefill tokens that are already computed."""
  150. return self._num_computed_tokens
  151. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  152. """Update number of tokens computed so far."""
  153. self._num_computed_tokens += num_new_computed_tokens
  154. assert self._num_computed_tokens <= self.get_len(), (
  155. self._num_computed_tokens, self.get_len())
  156. # If all tokens are computed, it means it is in decoding phase.
  157. if self.get_num_uncomputed_tokens() == 0:
  158. self._stage = SequenceStage.DECODE
  159. def reset_state_for_recompute(self) -> None:
  160. """Reset the number of computed tokens from this sequence. It is
  161. supposed to be called when a sequence needs to be started from
  162. the beginning again (e.g., sequence is preempted).
  163. """
  164. self._num_computed_tokens = 0
  165. self._stage = SequenceStage.PREFILL
  166. def get_num_uncomputed_tokens(self) -> int:
  167. """Return the number of prefill tokens that are not computed."""
  168. # we use `get_len()` which includes prompt_len + output_len instead
  169. # of prompt_len here. This is because during recompute we need to
  170. # prefill for both prompt and output.
  171. return self.get_len() - self.get_num_computed_tokens()
  172. def get_last_token_id(self) -> int:
  173. if not self._output_token_ids:
  174. return self._prompt_token_ids[-1]
  175. return self._output_token_ids[-1]
  176. def get_prompt_token_ids(self) -> Tuple[int, ...]:
  177. return self.prompt_token_ids
  178. def get_output_token_ids(self) -> Tuple[int, ...]:
  179. return self.output_token_ids
  180. @property
  181. def stage(self) -> SequenceStage:
  182. return self._stage
  183. def __repr__(self) -> str:
  184. return (f"SequenceData("
  185. f"prompt_token_ids={self._prompt_token_ids}, "
  186. f"output_token_ids={self._output_token_ids}, "
  187. f"cumulative_logprob={self.cumulative_logprob})")
  188. class Sequence:
  189. """Stores the data, status, and block information of a sequence.
  190. Args:
  191. seq_id: The ID of the sequence.
  192. inputs: The inputs of the sequence.
  193. block_size: The block size of the sequence. Should be the same as the
  194. block size used by the block manager and cache engine.
  195. lora_request: LoRA request.
  196. prompt_adapter_request: Prompt adapter request.
  197. """
  198. def __init__(
  199. self,
  200. seq_id: int,
  201. inputs: "LLMInputs",
  202. block_size: int,
  203. eos_token_id: Optional[int] = None,
  204. lora_request: Optional[LoRARequest] = None,
  205. prompt_adapter_request: Optional[PromptAdapterRequest] = None
  206. ) -> None:
  207. self.seq_id = seq_id
  208. self.inputs = inputs
  209. self.block_size = block_size
  210. self.eos_token_id = eos_token_id
  211. self.lora_request = lora_request
  212. self.prompt_adapter_request = prompt_adapter_request
  213. self.data = SequenceData(self.prompt_token_ids)
  214. self.output_logprobs: SampleLogprobs = []
  215. self.output_text = ""
  216. self.status = SequenceStatus.WAITING
  217. self.stop_reason: Union[int, str, None] = None
  218. # Used for incremental detokenization
  219. self.prefix_offset = 0
  220. self.read_offset = 0
  221. # Input + output tokens
  222. self.tokens: Optional[List[str]] = None
  223. @property
  224. def n_blocks(self) -> int:
  225. return math.ceil(self.get_len() / self.block_size)
  226. @property
  227. def prompt(self) -> Optional[str]:
  228. return self.inputs.get("prompt")
  229. @property
  230. def prompt_token_ids(self) -> List[int]:
  231. return self.inputs["prompt_token_ids"]
  232. @property
  233. def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
  234. return self.inputs.get("multi_modal_data")
  235. @property
  236. def lora_int_id(self) -> int:
  237. return self.lora_request.lora_int_id if self.lora_request else 0
  238. @property
  239. def prompt_adapter_id(self) -> int:
  240. return self.prompt_adapter_request.prompt_adapter_id \
  241. if self.prompt_adapter_request else 0
  242. def get_output_text_to_return(self, buffer_length: int):
  243. # We return the full output text if the sequence is finished.
  244. truncate = buffer_length and not self.is_finished()
  245. return self.output_text[:-buffer_length] if truncate else (
  246. self.output_text)
  247. def hash_of_block(self, logical_idx: int) -> int:
  248. # TODO This can produce incorrect hash when block size > prompt size
  249. # Compute the number of tokens in the sequence
  250. # TODO: The current hashing function is O(L^2). We should optimize
  251. # this in the future.
  252. num_tokens = self.num_hashed_tokens_of_block(logical_idx)
  253. hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
  254. return hash((hashed_tokens, self.lora_int_id))
  255. def num_hashed_tokens_of_block(self, logical_idx: int):
  256. return logical_idx * self.block_size + self.block_size
  257. def reset_state_for_recompute(self):
  258. """Reset the sequence states for recomputation."""
  259. self.data.reset_state_for_recompute()
  260. def append_token_id(
  261. self,
  262. token_id: int,
  263. logprobs: Dict[int, Logprob],
  264. ) -> None:
  265. assert token_id in logprobs
  266. self.output_logprobs.append(logprobs)
  267. self.data.append_token_id(token_id, logprobs[token_id].logprob)
  268. def get_len(self) -> int:
  269. return self.data.get_len()
  270. def get_prompt_len(self) -> int:
  271. return self.data.get_prompt_len()
  272. def get_output_len(self) -> int:
  273. return self.data.get_output_len()
  274. def get_token_ids(self) -> List[int]:
  275. return self.data.get_token_ids()
  276. def get_prompt_token_ids(self) -> Tuple[int, ...]:
  277. return self.data.get_prompt_token_ids()
  278. def get_last_token_id(self) -> int:
  279. return self.data.get_last_token_id()
  280. def get_output_token_ids(self) -> Tuple[int, ...]:
  281. return self.data.get_output_token_ids()
  282. def get_cumulative_logprob(self) -> float:
  283. return self.data.cumulative_logprob
  284. def get_beam_search_score(self,
  285. length_penalty: float = 1.0,
  286. seq_len: Optional[int] = None,
  287. eos_token_id: Optional[int] = None) -> float:
  288. """Calculate the beam search score with length penalty.
  289. Adapted from
  290. https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
  291. """
  292. if seq_len is None:
  293. seq_len = self.get_len()
  294. # NOTE: HF implementation does not count the EOS token
  295. # towards the length, we align with that here for testing.
  296. if (eos_token_id is not None
  297. and self.get_last_token_id() == eos_token_id):
  298. seq_len -= 1
  299. return self.get_cumulative_logprob() / (seq_len**length_penalty)
  300. def is_finished(self) -> bool:
  301. return SequenceStatus.is_finished(self.status)
  302. def fork(self, new_seq_id: int) -> "Sequence":
  303. new_seq = copy.deepcopy(self)
  304. new_seq.seq_id = new_seq_id
  305. return new_seq
  306. def get_num_new_tokens(self) -> int:
  307. """Get the number of new tokens to be computed.
  308. Returns:
  309. The new number of tokens to be computed. I.e., 1 for decode, or
  310. the remaining prompt size for prefill.
  311. """
  312. if self.data.stage == SequenceStage.DECODE:
  313. return 1
  314. return self.data.get_num_uncomputed_tokens()
  315. def is_prefill(self) -> bool:
  316. return self.data.stage == SequenceStage.PREFILL
  317. def __repr__(self) -> str:
  318. return (f"Sequence(seq_id={self.seq_id}, "
  319. f"status={self.status.name}, "
  320. f"num_blocks={self.n_blocks}, ")
  321. @dataclass
  322. class SequenceGroupState:
  323. """Mutable state tied to a specific sequence group"""
  324. # torch.Generator used in seeded sampling
  325. generator: Optional = None # type: ignore
  326. class SequenceGroup:
  327. """A group of sequences that are generated from the same prompt.
  328. Args:
  329. request_id: The ID of the request.
  330. seqs: The list of sequences.
  331. sampling_params: The sampling parameters used to generate the outputs.
  332. arrival_time: The arrival time of the request.
  333. lora_request: LoRA request.
  334. embeddings: The embeddings vectors of the prompt of the sequence group
  335. for an embedding model.
  336. pooling_params: The pooling parameters used to generate the pooling
  337. for an embedding model.
  338. encoder_seq: Optional, the single encoder sequence. Should be None
  339. unless you are working with an encoder/decoder model.
  340. prompt_adapter_request: Prompt adapter request.
  341. """
  342. def __init__(
  343. self,
  344. request_id: str,
  345. seqs: List[Sequence],
  346. arrival_time: float,
  347. sampling_params: Optional[SamplingParams] = None,
  348. lora_request: Optional[LoRARequest] = None,
  349. embeddings: Optional[List[float]] = None,
  350. pooling_params: Optional[PoolingParams] = None,
  351. encoder_seq: Optional[Sequence] = None,
  352. trace_headers: Optional[Dict[str, str]] = None,
  353. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  354. ) -> None:
  355. self.request_id = request_id
  356. self.seqs_dict = {seq.seq_id: seq for seq in seqs}
  357. self.sampling_params = sampling_params
  358. self.metrics = RequestMetrics(arrival_time=arrival_time,
  359. last_token_time=arrival_time,
  360. first_scheduled_time=None,
  361. first_token_time=None,
  362. time_in_queue=None)
  363. self.lora_request = lora_request
  364. self.prompt_logprobs: Optional[PromptLogprobs] = None
  365. self.state = SequenceGroupState()
  366. self.embeddings = embeddings
  367. self.pooling_params = pooling_params
  368. self.prompt_adapter_request = prompt_adapter_request
  369. self.encoder_seq = encoder_seq
  370. self.trace_headers = trace_headers
  371. self._first_seq = next(iter(self.seqs_dict.values()))
  372. @property
  373. def prompt(self) -> Optional[str]:
  374. # All sequences in the group should have the same prompt.
  375. # We use the prompt of an arbitrary sequence.
  376. return self._first_seq.prompt
  377. @property
  378. def prompt_token_ids(self) -> List[int]:
  379. # All sequences in the group should have the same prompt.
  380. # We use the prompt of an arbitrary sequence.
  381. return self._first_seq.prompt_token_ids
  382. @property
  383. def multi_modal_data(self) -> "MultiModalDataDict":
  384. # All sequences in the group should have the same multi-modal data.
  385. # We use the multi-modal data of an arbitrary sequence.
  386. return self._first_seq.multi_modal_data
  387. @property
  388. def lora_int_id(self) -> int:
  389. return self.lora_request.lora_int_id if self.lora_request else 0
  390. @property
  391. def prompt_adapter_id(self) -> int:
  392. return self.prompt_adapter_request.prompt_adapter_id \
  393. if self.prompt_adapter_request else 0
  394. @property
  395. def prompt_adapter_num_virtual_tokens(self) -> int:
  396. return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
  397. if self.prompt_adapter_request else 0
  398. def get_last_latency(self, now: float) -> Optional[float]:
  399. """Sets the last token time for Request level timings."""
  400. # If still in prefill phase, raise Error.
  401. if self.is_prefill():
  402. raise ValueError(
  403. "seq_group.get_last_latency() should not be called "
  404. "if the seq_group is in prefill phase.")
  405. # Otherwise return token latency.
  406. latency = now - self.metrics.last_token_time
  407. self.metrics.last_token_time = now
  408. return latency
  409. def maybe_set_first_token_time(self, time: float) -> None:
  410. """Sets the first token time for Request level timings."""
  411. # NOTE: in a case where a sequence_group is swapped and
  412. # recomputed, the time between iterations is counted
  413. # in TPOT, rather than recalculating TTFT (since from the )
  414. # POV of the user, there is simply a long generation delay.
  415. if (self.metrics.first_token_time is None
  416. and self.get_seqs()[0].get_output_len() == 1):
  417. self.metrics.first_token_time = time
  418. def maybe_set_first_scheduled_time(self, time: float) -> None:
  419. """Sets the first scheduled time and time in queue for Request
  420. level timings."""
  421. if self.metrics.first_scheduled_time is None:
  422. self.metrics.first_scheduled_time = time
  423. self.metrics.time_in_queue = time - self.metrics.arrival_time
  424. def set_finished_time(self, time: Optional[float]) -> None:
  425. """Sets the finished time for Request level timings."""
  426. self.metrics.finished_time = time
  427. def get_max_num_running_seqs(self) -> int:
  428. """The maximum number of sequences running in parallel in the remaining
  429. lifetime of the request."""
  430. if self.sampling_params and self.sampling_params.use_beam_search:
  431. # For beam search, maximally there will always be `best_of` beam
  432. # candidates running in the future.
  433. return self.sampling_params.best_of
  434. else:
  435. if (self.sampling_params
  436. and self.sampling_params.best_of > self.num_seqs()):
  437. # At prompt stage, the sequence group is not yet filled up
  438. # and only have one sequence running. However, in the
  439. # generation stage, we will have `best_of` sequences running.
  440. return self.sampling_params.best_of
  441. # At sampling stages, return the number of actual sequences
  442. # that are not finished yet.
  443. return self.num_unfinished_seqs()
  444. def get_seqs(
  445. self,
  446. status: Optional[SequenceStatus] = None,
  447. ) -> List[Sequence]:
  448. return list(self.seqs_dict.values()) if status is None else [
  449. seq for seq in self.seqs_dict.values() if seq.status == status
  450. ]
  451. def is_encoder_decoder(self) -> bool:
  452. return self.encoder_seq is not None
  453. def get_encoder_seq(self) -> Optional[Sequence]:
  454. return self.encoder_seq
  455. def get_unfinished_seqs(self) -> List[Sequence]:
  456. return [
  457. seq for seq in self.seqs_dict.values() if not seq.is_finished()
  458. ]
  459. def get_finished_seqs(self) -> List[Sequence]:
  460. return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
  461. def update_num_computed_tokens(self, num_new_computed_tokens: int):
  462. """Update number of tokens computed so far."""
  463. for seq in self.seqs_dict.values():
  464. if not seq.is_finished():
  465. seq.data.update_num_computed_tokens(num_new_computed_tokens)
  466. def get_num_uncomputed_tokens(self) -> int:
  467. num_uncomputed_tokens = 0
  468. for seq in self.get_seqs():
  469. if not seq.is_finished():
  470. num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
  471. return num_uncomputed_tokens
  472. def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
  473. # Optimization. We don't need to call get_seqs if we don't need to
  474. # filter by states.
  475. if status is None:
  476. return len(self.seqs_dict)
  477. return len(self.get_seqs(status))
  478. def num_unfinished_seqs(self) -> int:
  479. return len(self.get_unfinished_seqs())
  480. def num_finished_seqs(self) -> int:
  481. return len(self.get_finished_seqs())
  482. def find(self, seq_id: int) -> Sequence:
  483. if seq_id not in self.seqs_dict:
  484. raise ValueError(f"Sequence {seq_id} not found.")
  485. return self.seqs_dict[seq_id]
  486. def add(self, seq: Sequence) -> None:
  487. if seq.seq_id in self.seqs_dict:
  488. raise ValueError(f"Sequence {seq.seq_id} already exists.")
  489. self.seqs_dict[seq.seq_id] = seq
  490. def remove(self, seq_id: int) -> None:
  491. if seq_id not in self.seqs_dict:
  492. raise ValueError(f"Sequence {seq_id} not found.")
  493. del self.seqs_dict[seq_id]
  494. def is_finished(self) -> bool:
  495. return all(seq.is_finished() for seq in self.get_seqs())
  496. def is_prefill(self) -> bool:
  497. # Every sequence should be in the same stage.
  498. return self.get_seqs()[0].is_prefill()
  499. def __repr__(self) -> str:
  500. return (f"SequenceGroup(request_id={self.request_id}, "
  501. f"sampling_params={self.sampling_params}, "
  502. f"num_seqs={len(self.seqs_dict)})")
  503. class SequenceGroupMetadata:
  504. """Metadata for a sequence group. Used to create `AttentionMetadata`.
  505. Args:
  506. request_id: The ID of the request.
  507. is_prompt: Whether the request is at prompt stage.
  508. seq_data: The sequence data. (Seq id -> sequence data)
  509. sampling_params: The sampling parameters used to generate the outputs.
  510. block_tables: The block tables. (Seq id -> list of physical block
  511. numbers)
  512. do_sample: True if sampling is required. Sampling is not required when
  513. e.g., prefill is chunked, and the current iteration only computes
  514. query tokens for prefill, we don't need sampling.
  515. token_chunk_size: The number of tokens to be processed (per sequence).
  516. None if chunking is not required.
  517. lora_request: LoRA request.
  518. computed_block_nums: The block numbers that are already computed,
  519. used in prefix caching.
  520. state: Internal state tied to this sequence group.
  521. multi_modal_data: Multi modal data.
  522. encoder_seq_data: Optional sequence data for encoder prompt
  523. (SequenceGroup.encoder_seq). Should be None
  524. unless you are working with an encoder/decoder
  525. model.
  526. cross_block_table: Optional cross-attention block table associated
  527. with the encoder prompt
  528. (SequenceGroup.encoder_seq). Should be None
  529. unless you are working with an encoder/decoder
  530. model.
  531. prompt_adapter_request: Prompt Adapter request.
  532. """
  533. def __init__(
  534. self,
  535. request_id: str,
  536. is_prompt: bool,
  537. seq_data: Dict[int, SequenceData],
  538. sampling_params: SamplingParams,
  539. block_tables: Dict[int, List[int]],
  540. do_sample: bool = True,
  541. pooling_params: Optional[PoolingParams] = None,
  542. token_chunk_size: Optional[int] = None,
  543. lora_request: Optional[LoRARequest] = None,
  544. computed_block_nums: Optional[List[int]] = None,
  545. state: Optional[SequenceGroupState] = None,
  546. multi_modal_data: Optional["MultiModalDataDict"] = None,
  547. encoder_seq_data: Optional[SequenceData] = None,
  548. cross_block_table: Optional[List[int]] = None,
  549. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  550. ) -> None:
  551. self.request_id = request_id
  552. self.is_prompt = is_prompt
  553. self.seq_data = seq_data
  554. self.sampling_params = sampling_params
  555. self.block_tables = block_tables
  556. self.pooling_params = pooling_params
  557. self.lora_request = lora_request
  558. self.prompt_adapter_request = prompt_adapter_request
  559. self.computed_block_nums = computed_block_nums
  560. self.multi_modal_data = multi_modal_data
  561. self.state = SequenceGroupState() if state is None else state
  562. self.encoder_seq_data = encoder_seq_data
  563. self.cross_block_table = cross_block_table
  564. self._token_chunk_size = token_chunk_size
  565. self.do_sample = do_sample
  566. # The number of speculative tokens adopted in this request.
  567. # None means specuative decoding is not used.
  568. # Zero means speculative decoding is disabled for some reasons.
  569. # TODO: We should maintain this states out of the sequence group.
  570. self.num_speculative_tokens = None
  571. if self._token_chunk_size is None:
  572. if is_prompt:
  573. self._token_chunk_size = list(seq_data.values())[0].get_len()
  574. else:
  575. self._token_chunk_size = 1
  576. @property
  577. def lora_int_id(self) -> int:
  578. return self.lora_request.lora_int_id if self.lora_request else 0
  579. @property
  580. def prompt_adapter_id(self) -> int:
  581. return self.prompt_adapter_request.prompt_adapter_id \
  582. if self.prompt_adapter_request else 0
  583. @property
  584. def prompt_adapter_num_virtual_tokens(self) -> int:
  585. return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
  586. if self.prompt_adapter_request else 0
  587. @property
  588. def token_chunk_size(self) -> int:
  589. """Return the number of tokens to be processed (chunk size)."""
  590. assert self._token_chunk_size is not None
  591. return self._token_chunk_size
  592. class SequenceOutput:
  593. """The model output associated with a sequence.
  594. Args:
  595. parent_seq_id: The ID of the parent sequence (for forking in beam
  596. search).
  597. output_token: The output token ID.
  598. logprobs: The logprobs of the output token.
  599. (Token id -> logP(x_i+1 | x_0, ..., x_i))
  600. """
  601. def __init__(
  602. self,
  603. parent_seq_id: int,
  604. output_token: int,
  605. logprobs: Dict[int, Logprob],
  606. ) -> None:
  607. self.parent_seq_id = parent_seq_id
  608. self.output_token = output_token
  609. self.logprobs = logprobs
  610. def __repr__(self) -> str:
  611. return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
  612. f"output_token={self.output_token}, "
  613. f"logprobs={self.logprobs})")
  614. def __eq__(self, other: object) -> bool:
  615. if not isinstance(other, SequenceOutput):
  616. raise NotImplementedError()
  617. equal = (self.parent_seq_id == other.parent_seq_id
  618. and self.output_token == other.output_token)
  619. log_probs_equal = other.logprobs == self.logprobs
  620. return equal and log_probs_equal
  621. class SequenceGroupOutput(ABC):
  622. """The base class for model outputs associated with a sequence group."""
  623. @abstractmethod
  624. def __repr__(self) -> str:
  625. pass
  626. @abstractmethod
  627. def __eq__(self, other: object) -> bool:
  628. pass
  629. class CompletionSequenceGroupOutput(SequenceGroupOutput):
  630. """The model output associated with a completion sequence group."""
  631. def __init__(
  632. self,
  633. samples: List[SequenceOutput],
  634. prompt_logprobs: Optional[PromptLogprobs],
  635. ) -> None:
  636. self.samples = samples
  637. # Prompt logprob for each prompt query token.
  638. self.prompt_logprobs = prompt_logprobs
  639. def __repr__(self) -> str:
  640. return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
  641. f"prompt_logprobs={self.prompt_logprobs})")
  642. def __eq__(self, other: object) -> bool:
  643. if not isinstance(other, CompletionSequenceGroupOutput):
  644. raise NotImplementedError()
  645. return (self.samples == other.samples
  646. and self.prompt_logprobs == other.prompt_logprobs)
  647. class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
  648. """The model output associated with an embedding sequence group."""
  649. def __init__(
  650. self,
  651. embeddings: List[float],
  652. ) -> None:
  653. self.embeddings = embeddings
  654. def __repr__(self) -> str:
  655. return (f"EmbeddingSequenceGroupOutput("
  656. f"embeddings_shape={len(self.embeddings)})")
  657. def __eq__(self, other: object) -> bool:
  658. if not isinstance(other, EmbeddingSequenceGroupOutput):
  659. raise NotImplementedError()
  660. return self.embeddings == other.embeddings
  661. @dataclass
  662. class IntermediateTensors:
  663. """For all pipeline stages except the last, we need to return the hidden
  664. states and residuals to be sent to the next stage. This data structure
  665. contains the hidden states and residuals for a request.
  666. """
  667. tensors: Dict[str, torch.Tensor]
  668. def __getitem__(self, key: Union[str, slice]):
  669. if isinstance(key, str):
  670. return self.tensors[key]
  671. elif isinstance(key, slice):
  672. return self.__class__({k: v[key] for k, v in self.tensors.items()})
  673. def __setitem__(self, key: str, value):
  674. self.tensors[key] = value
  675. def __len__(self):
  676. return len(self.tensors)
  677. def __eq__(self, other: object):
  678. return isinstance(other, self.__class__) and self
  679. def __repr__(self) -> str:
  680. return f"IntermediateTensors(tensors={self.tensors})"
  681. @dataclass
  682. class SamplerOutput:
  683. """For each sequence group, we generate a list of SequenceOutput object,
  684. each of which contains one possible candidate for the next token.
  685. This data structure implements methods, so it can be used like a list, but
  686. also has optional fields for device tensors.
  687. """
  688. outputs: List[CompletionSequenceGroupOutput]
  689. # On-device tensor containing probabilities of each token.
  690. sampled_token_probs: Optional[torch.Tensor] = None
  691. # On-device tensor containing the logprobs of each token.
  692. logprobs: Optional["torch.Tensor"] = None
  693. # On-device tensor containing the sampled token ids.
  694. sampled_token_ids: Optional[torch.Tensor] = None
  695. # Spec decode metrics populated by workers.
  696. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  697. # Optional last hidden states from the model.
  698. hidden_states: Optional[torch.Tensor] = None
  699. def __getitem__(self, idx: int):
  700. return self.outputs[idx]
  701. def __setitem__(self, idx: int, value):
  702. self.outputs[idx] = value
  703. def __len__(self):
  704. return len(self.outputs)
  705. def __eq__(self, other: object):
  706. return isinstance(other,
  707. self.__class__) and self.outputs == other.outputs
  708. def __repr__(self) -> str:
  709. """Show the shape of a tensor instead of its values to reduce noise.
  710. """
  711. sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
  712. else self.sampled_token_probs.shape)
  713. sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
  714. self.sampled_token_ids.shape)
  715. return (
  716. f"SamplerOutput(outputs={self.outputs}, "
  717. f"sampled_token_probs={sampled_token_probs_repr}, "
  718. f"sampled_token_ids={sampled_token_ids_repr}, "
  719. f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
  720. @dataclass
  721. class PoolerOutput:
  722. """The output from a pooling operation in the embedding model."""
  723. outputs: List[EmbeddingSequenceGroupOutput]
  724. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
  725. def __getitem__(self, idx: int):
  726. return self.outputs[idx]
  727. def __setitem__(self, idx: int, value):
  728. self.outputs[idx] = value
  729. def __len__(self):
  730. return len(self.outputs)
  731. def __eq__(self, other: object):
  732. return isinstance(other,
  733. self.__class__) and self.outputs == other.outputs
  734. def get_all_seq_ids(
  735. seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
  736. """Given a list of SequenceGroupMetadata, create a list of all
  737. sequence ids.
  738. """
  739. return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
  740. def get_all_seq_ids_and_request_ids(
  741. seq_group_metadata_list: List[SequenceGroupMetadata]
  742. ) -> Tuple[List[int], Dict[str, Set[int]]]:
  743. """Given a list of SequenceGroupMetadata, create a list of all
  744. sequence ids.
  745. """
  746. seq_ids: List[int] = []
  747. request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
  748. for sg in seq_group_metadata_list:
  749. for seq_id in sg.seq_data:
  750. seq_ids.append(seq_id)
  751. request_id_seq_ids_mapping[sg.request_id].add(seq_id)
  752. return seq_ids, request_id_seq_ids_mapping
  753. class HiddenStates:
  754. """Hidden states corresponding to in-progress sequences.
  755. Used in speculative decoding to pass hidden states from
  756. the target model to the proposer model in the subsequent step.
  757. seq_ids are the sequence ids of each entry of the batch
  758. dimension of the hidden_states tensor"""
  759. def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
  760. hidden_states: torch.Tensor):
  761. assert len(seq_group_metadata_list) == len(hidden_states)
  762. self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
  763. self.hidden_states: torch.Tensor = hidden_states
  764. def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
  765. hidden_states: torch.Tensor) -> None:
  766. """Update hidden states from target model invocation."""
  767. assert len(seq_group_metadata_list) == len(hidden_states)
  768. self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
  769. self.hidden_states = torch.cat([self.hidden_states, hidden_states])
  770. def prune(self,
  771. seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
  772. """Prune to provided list of sequence ids."""
  773. seq_ids = get_all_seq_ids(seq_group_metadata_list)
  774. if seq_ids != self.seq_ids:
  775. # Batch contents changed - prune removed sequences.
  776. index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
  777. self.hidden_states = self.hidden_states[index]
  778. self.seq_ids = seq_ids
  779. @dataclass
  780. class ExecuteModelRequest:
  781. """The model execution request, containing CPU metadata only. The LLM
  782. engine should create an instance of this class for each request batch."""
  783. # The sequence group metadata list.
  784. seq_group_metadata_list: List[SequenceGroupMetadata]
  785. # Blocks to swap in. List of CPU -> GPU block number.
  786. blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
  787. # Blocks to swap out. List of GPU -> CPU block number.
  788. blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
  789. # Blocks to copy. Source to dest block.
  790. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
  791. # Virtual engine ID for pipeline parallel.
  792. virtual_engine: int = 0
  793. # The number of slots for lookahead decoding.
  794. num_lookahead_slots: int = 0
  795. # The number of requests in the running queue.
  796. running_queue_size: int = 0
  797. # Optional hidden states from prior step.
  798. previous_hidden_states: Optional[HiddenStates] = None
  799. # The number of forward steps to run.
  800. num_steps: int = 1
  801. # Finished request ids since last step.
  802. finished_requests_ids: List[str] = field(default_factory=list)
  803. def clone(
  804. self, seq_group_metadata_list: List[SequenceGroupMetadata]
  805. ) -> "ExecuteModelRequest":
  806. """Clone the request with a new sequence group metadata list."""
  807. return ExecuteModelRequest(
  808. seq_group_metadata_list=seq_group_metadata_list,
  809. blocks_to_swap_in=self.blocks_to_swap_in.copy(),
  810. blocks_to_swap_out=self.blocks_to_swap_out.copy(),
  811. blocks_to_copy=self.blocks_to_copy.copy(),
  812. virtual_engine=self.virtual_engine,
  813. num_lookahead_slots=self.num_lookahead_slots,
  814. running_queue_size=self.running_queue_size,
  815. previous_hidden_states=self.previous_hidden_states,
  816. num_steps=self.num_steps,
  817. finished_requests_ids=self.finished_requests_ids,
  818. )