sequence.py

  1. """Sequence and its related classes."""
  2. import copy
  3. import enum
  4. from dataclasses import dataclass
  5. from typing import Dict, List, Optional, Union, TYPE_CHECKING
  6. from aphrodite.common.block import LogicalTokenBlock
  7. from aphrodite.common.sampling_params import SamplingParams
  8. from aphrodite.lora.request import LoRARequest
  9. if TYPE_CHECKING:
  10. import torch
  11. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  12. @dataclass
  13. class Logprob:
  14. """Infos for supporting OpenAI compatible logprobs."""
  15. logprob: float
  16. decoded_token: Optional[str] = None
  17. PromptLogprobs = List[Optional[Dict[int, Logprob]]]
  18. SampleLogprobs = List[Dict[int, Logprob]]
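
# Illustrative sketch (not part of the module): the shapes of the two logprob
# aliases above. SampleLogprobs carries one dict per generated token. The
# Optional element type in PromptLogprobs suggests that a prompt position may
# carry no logprob at all (e.g. the first token, which has no preceding
# context) -- that reading is an assumption based on the type alone:
#
#     sample_logprobs: SampleLogprobs = [{44: Logprob(logprob=-0.5)}]
#     prompt_logprobs: PromptLogprobs = [None, {22: Logprob(logprob=-2.3)}]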


class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        if status == SequenceStatus.FINISHED_STOPPED:
            finish_reason = "stop"
        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            # Ignored sequences are those whose prompts are longer than the
            # model's length cap. Therefore, the finish reason is also
            # "length", matching the OpenAI API.
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason
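
# Illustrative sketch (not part of the module): how finished statuses map to
# OpenAI-style finish reasons. Note that FINISHED_IGNORED also maps to
# "length", and non-finished statuses map to None:
#
#     assert SequenceStatus.get_finished_reason(
#         SequenceStatus.FINISHED_STOPPED) == "stop"
#     assert SequenceStatus.get_finished_reason(
#         SequenceStatus.FINISHED_IGNORED) == "length"
#     assert SequenceStatus.get_finished_reason(SequenceStatus.RUNNING) is None
#     assert not SequenceStatus.is_finished(SequenceStatus.SWAPPED)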


@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Args:
        arrival_time: The time when the request arrived.
        last_token_time: The time when the last token was generated.
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None


class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list
            if None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        return self.output_token_ids

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")


class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token ID, if any.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None
        self.persistent_data = {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def hash_of_block(self, logical_idx: int) -> int:
        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(
            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        return logical_idx * self.block_size + self.block_size

    def _append_logical_block(self) -> None:
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        assert token_id in logprobs
        self._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(
        self,
        length_penalty: float = 1.0,
        seq_len: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from
        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling
    generator: Optional["torch.Generator"] = None


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(
            arrival_time=arrival_time,
            last_token_time=arrival_time,
            first_scheduled_time=None,
            first_token_time=None,
            time_in_queue=None,
        )
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> float:
        """Gets last token latency for Request level timings."""
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        if self.metrics.first_token_time is None:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the
        remaining lifetime of the request."""
        if self.sampling_params.use_beam_search:
            # For beam search, there will always be at most `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
            if self.sampling_params.best_of > self.num_seqs():
                # At the prompt stage, the sequence group is not yet filled
                # up and only has one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
                return self.sampling_params.best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        return (list(self.seqs_dict.values()) if status is None else [
            seq for seq in self.seqs_dict.values() if seq.status == status
        ])

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")


class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        persistent_data: The persistent data of the sequence group.
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed.
        state: Internal state tied to this sequence group.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        persistent_data: Dict[int, dict],
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.persistent_data = persistent_data
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.state = SequenceGroupState() if state is None else state

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0
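
# Illustrative sketch (not part of the module): `seq_data`, `block_tables`,
# and `persistent_data` are all keyed by sequence ID. A prompt-stage group
# with a single sequence (ID 0) whose tokens live in physical blocks 7 and 8
# might look like the following (IDs and block numbers are made up):
#
#     metadata = SequenceGroupMetadata(
#         request_id="req-0",
#         is_prompt=True,
#         seq_data={0: SequenceData([11, 22, 33])},
#         sampling_params=SamplingParams(),
#         block_tables={0: [7, 8]},
#         persistent_data={0: {}},
#     )
#     assert metadata.lora_int_id == 0  # no LoRA request attached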


class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        persistent_data: The persistent data of the sequence.
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
        persistent_data: dict,
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs
        self.persistent_data = persistent_data

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs}, "
                f"persistent_data={self.persistent_data})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutput):
            raise NotImplementedError()
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal


class SequenceGroupOutput:
    """The model output associated with a sequence group."""

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceGroupOutput):
            raise NotImplementedError()
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput objects,
    each of which contains one possible candidate for the next token.

    This data structure implements methods so it can be used like a list, but
    also has optional fields for device tensors.
    """
    outputs: List[SequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        return (isinstance(other, self.__class__)
                and self.outputs == other.outputs)
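
# Illustrative sketch (not part of the module): a SamplerOutput holds one
# SequenceGroupOutput per scheduled sequence group and supports list-style
# access through __getitem__/__len__. The token IDs are made up:
#
#     sample = SequenceOutput(parent_seq_id=0, output_token=44,
#                             logprobs={44: Logprob(logprob=-0.5)},
#                             persistent_data={})
#     group_output = SequenceGroupOutput(samples=[sample],
#                                        prompt_logprobs=None)
#     sampler_output = SamplerOutput(outputs=[group_output])
#     assert len(sampler_output) == 1
#     assert sampler_output[0] is group_output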