12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322 |
- """Sequence and its related classes."""
- import copy
- import enum
- from abc import ABC, abstractmethod
- from array import array
- from collections import defaultdict
- from dataclasses import dataclass
- from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
- Tuple, Union, cast)
- import msgspec
- import torch
- from aphrodite.common.pooling_params import PoolingParams
- from aphrodite.common.sampling_params import SamplingParams
- from aphrodite.inputs.parse import is_valid_encoder_decoder_llm_inputs
- from aphrodite.lora.request import LoRARequest
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
- from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
- if TYPE_CHECKING:
- from aphrodite.inputs import LLMInputs
- from aphrodite.multimodal.base import MultiModalDataDict
# `array.array` typecode shared by every token-id buffer built in this
# module ("l" is a C signed long), so all token arrays agree on element type.
APHRODITE_TOKEN_ID_ARRAY_TYPE: str = "l"
# Logprob is a plain dataclass (not a msgspec.Struct) for now because it is
# also part of the OpenAI server output, where msgspec objects are not
# serializable.
# TODO: Fix it.
@dataclass
class Logprob:
    """Infos for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1), or None if unknown.
        decoded_token: The decoded text of the chosen token, or None if it
            has not been detokenized.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# Prompt logprobs of a sequence group: a list of optional
# {token_id -> Logprob} dicts (presumably one per prompt position, with None
# where no logprob is available -- confirm against the producer).
# SequenceGroup stores this as Optional[PromptLogprobs], which is None when
# the group has no prompt logprobs.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Sample logprobs of a single sequence: one {token_id -> Logprob} dict per
# generated token (Sequence.append_token_id appends one per token).
SampleLogprobs = List[Dict[int, Logprob]]
class SequenceStatus(enum.IntEnum):
    """Lifecycle state of a single sequence.

    Every member whose value is greater than ``SWAPPED`` is a finished
    state; keep that ordering when adding new members.
    """
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Values above SWAPPED (2) are treated as finished statuses.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when ``status`` is one of the FINISHED_* states."""
        return status > SequenceStatus.SWAPPED

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Translate a finished status into an OpenAI-style finish reason.

        Returns:
            "stop", "length" or "abort" for finished statuses, and None for
            any status that is not finished.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        # FINISHED_IGNORED marks sequences whose prompt is longer than the
        # model's length cap, so, as in the OpenAI API, it is reported as a
        # "length" stop just like FINISHED_LENGTH_CAPPED.
        if status in (SequenceStatus.FINISHED_LENGTH_CAPPED,
                      SequenceStatus.FINISHED_IGNORED):
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        return None
class SequenceStage(enum.Enum):
    """Which generation phase a sequence is currently in."""

    # Some known tokens have not yet been run through the model (the
    # prompt, or prompt plus output after a recompute).
    PREFILL = enum.auto()
    # Every known token has been computed; one new token per step.
    DECODE = enum.auto()
@dataclass
class RequestMetrics:
    """Metrics associated with a request.

    Attributes:
        arrival_time: The time when the request arrived.
        last_token_time: The reference time for the next token-latency
            measurement (SequenceGroup initializes it to the arrival time
            and advances it on every latency query).
        first_scheduled_time: The time when the request was first scheduled.
        first_token_time: The time when the first token was generated.
        time_in_queue: The time the request spent in the queue, i.e. the
            first scheduled time minus the arrival time.
        finished_time: The time when the request was finished.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Produced by ``SequenceData.get_delta_and_reset`` and consumed by
    ``SequenceData.apply_delta``. Field order matters: the struct is
    encoded as an array (``array_like=True``).
    """
    # Output tokens appended since the previous delta was taken; they are
    # appended to the existing SequenceData.
    new_output_token_ids: List[int]
    # Overwrites the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwrites the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwrites the existing `stage`.
    new_stage: SequenceStage
class SequenceData(msgspec.Struct,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt, as an ``array`` with
            typecode ``APHRODITE_TOKEN_ID_ARRAY_TYPE``.
        output_token_ids: The token IDs of the output, as an ``array`` with
            the same typecode. Defaults to an empty array.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """
    # NOTE: we cannot use Union[List, array] because msgspec cannot support
    # union of 2 list types.
    _prompt_token_ids: array
    _output_token_ids: array = msgspec.field(
        default_factory=lambda: array(APHRODITE_TOKEN_ID_ARRAY_TYPE, []))

    ### The below fields should not be passed as an argument ###
    _cumulative_logprob: float = 0.0
    _prompt_token_ids_tuple: Tuple[int,
                                   ...] = msgspec.field(default_factory=tuple)
    # The number of tokens that are computed (that run against the model).
    _num_computed_tokens: int = 0
    _stage: SequenceStage = SequenceStage.PREFILL
    # Flat list of prompt + output token ids. Rebuilt by
    # `_update_cached_all_tokens` and extended in place on every append.
    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

    # It is used to get delta input. It is reset when `get_delta_and_reset`
    # is called.
    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self) -> None:
        # Both buffers must use the module-wide token-id typecode.
        assert (self._prompt_token_ids.typecode ==
                APHRODITE_TOKEN_ID_ARRAY_TYPE)
        assert (self._output_token_ids.typecode ==
                APHRODITE_TOKEN_ID_ARRAY_TYPE)

        # Immutable, hashable copy of the prompt, returned by
        # `prompt_token_ids` and used for prefix hashing.
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
            self._prompt_token_ids)
        self._update_cached_all_tokens()

    def _update_cached_all_tokens(self):
        """Rebuild the cached flat list of prompt + output token ids."""
        assert isinstance(self._prompt_token_ids, array)
        assert isinstance(self._output_token_ids, array)
        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                     self._output_token_ids)

    @property
    def cumulative_logprob(self) -> float:
        return self._cumulative_logprob

    @property
    def prompt_token_ids(self) -> Tuple[int, ...]:
        return self._prompt_token_ids_tuple

    @prompt_token_ids.setter
    def prompt_token_ids(self, new_prompt_token_ids) -> None:
        # The prompt cannot be replaced after construction.
        raise NotImplementedError

    @property
    def prompt_token_ids_array(self) -> array:
        """Return the prompt token ids as the underlying array.

        The array uses typecode ``APHRODITE_TOKEN_ID_ARRAY_TYPE`` ("l", a C
        signed long). Its item size is platform-dependent (at least 4
        bytes; 8 on most 64-bit Unix platforms), so it only matches
        ``torch.long`` (int64) where a C long is 8 bytes. Check
        ``itemsize`` before reinterpreting the buffer.
        """
        return self._prompt_token_ids

    @property
    def output_token_ids(self) -> Tuple[int, ...]:
        return tuple(self._output_token_ids)

    @output_token_ids.setter
    def output_token_ids(self, new_output_token_ids: List[int]) -> None:
        """Replace the output tokens and rebuild the cached token list."""
        self._output_token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       new_output_token_ids)
        self._update_cached_all_tokens()

    @property
    def output_token_ids_array(self) -> array:
        """Return the output token ids as the underlying array.

        The array uses typecode ``APHRODITE_TOKEN_ID_ARRAY_TYPE`` ("l", a C
        signed long). Its item size is platform-dependent (at least 4
        bytes; 8 on most 64-bit Unix platforms), so it only matches
        ``torch.long`` (int64) where a C long is 8 bytes. Check
        ``itemsize`` before reinterpreting the buffer.
        """
        assert isinstance(self._output_token_ids, array)
        return self._output_token_ids

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one generated token and accumulate its logprob."""
        self._output_token_ids.append(token_id)
        # Recorded separately so the next delta only carries new tokens.
        self._new_appended_tokens.append(token_id)
        self._cached_all_token_ids.append(token_id)
        self._cumulative_logprob += logprob

    def get_len(self) -> int:
        return len(self._output_token_ids) + len(self._prompt_token_ids)

    def get_prompt_len(self) -> int:
        return len(self._prompt_token_ids)

    def get_output_len(self) -> int:
        return len(self._output_token_ids)

    def get_token_ids(self) -> List[int]:
        # NOTE: returns the internal cached list, not a copy.
        return self._cached_all_token_ids

    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable.

        Returns:
            A ``(prompt_part, output_part)`` pair covering the first
            ``num_tokens`` tokens. ``output_part`` is None when the prefix
            lies entirely within the prompt.
        """
        prompt_length = self.get_prompt_len()
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self._prompt_token_ids_tuple[:num_tokens], None)

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far.

        Switches the stage to DECODE once every token has been computed.
        """
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL
        self._new_appended_tokens = []

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token if no
        output has been generated yet."""
        if not self._output_token_ids:
            return self._prompt_token_ids[-1]
        return self._output_token_ids[-1]

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.output_token_ids

    def get_delta_and_reset(self) -> SequenceDataDelta:
        """Package the state changes since the last call into a delta and
        clear the pending appended-token list."""
        delta = SequenceDataDelta(self._new_appended_tokens,
                                  self._cumulative_logprob,
                                  self.get_num_computed_tokens(), self.stage)
        # Reset delta state.
        self._new_appended_tokens = []
        return delta

    def apply_delta(self, delta: SequenceDataDelta):
        """Apply a delta produced by `get_delta_and_reset`.

        Scalar fields are overwritten; new output tokens are appended.
        """
        self._num_computed_tokens = delta.new_num_computed_tokens
        self._cumulative_logprob = delta.new_cumulative_logprob
        self._stage = delta.new_stage
        self._output_token_ids.extend(delta.new_output_token_ids)
        self._cached_all_token_ids.extend(delta.new_output_token_ids)

    @property
    def stage(self) -> SequenceStage:
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self._prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob}, "
                f"get_num_computed_tokens={self.get_num_computed_tokens()})")
class Sequence:
    """Stores the data, status, and block information of a sequence.

    The sequence is constructed from the LLMInputs instance passed
    in through the `inputs` constructor argument.

    For encoder/decoder models, LLMInputs encapsulates both a
    decoder and encoder prompt, creating an ambiguity about which
    prompt to construct the sequence from. The `from_decoder_prompt`
    constructor argument signals whether to construct the Sequence
    from the LLMInputs decoder prompt, or encoder prompt.

    Args:
        seq_id: The ID of the sequence.
        inputs: The inputs of the sequence.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence (EOS) token id recognized by this
            LLM.
        lora_request: LoRA request.
        prompt_adapter_request: Prompt Adapter request.
        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
            (True) or encoder prompt (False). Must be True for decoder-only
            models.

    Raises:
        ValueError: If `from_decoder_prompt` is False and `inputs` does not
            carry valid encoder prompt fields.
    """

    def __init__(
        self,
        seq_id: int,
        inputs: "LLMInputs",
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        from_decoder_prompt: bool = True,
    ) -> None:
        self.seq_id = seq_id
        self.inputs = inputs
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.from_decoder_prompt = from_decoder_prompt
        # Resolved lazily from `inputs` by the `prompt` and
        # `prompt_token_ids` properties, then cached here.
        self._prompt: Optional[str] = None
        self._prompt_token_ids: Optional[List[int]] = None

        # For decoder-only models, a Sequence is constructed
        # from an LLMInputs instance (the `inputs` arg.)
        #
        # For encoder/decoder models the same `inputs`
        # instance could be utilized to construct either an
        # encoder sequence or a decoder sequence, because
        # `LLMInputs` has both decoder- and encoder-oriented
        # member variables (i.e. it encapsulates both an encoder
        # and a decoder prompt.) The decision of which type of sequence
        # to generate is determined by the `from_decoder_prompt` argument.
        #
        # When constructing an encoder sequence
        # (`from_decoder_prompt` False) it matters that
        # the `LLMInputs` instance stored in `inputs` is valid
        # in the sense that its encoder-related member variables are
        # populated; below, an exception is raised if this is
        # not the case.
        #
        # When constructing a decoder sequence (`from_decoder_prompt` True)
        # it does not matter whether `inputs` has its encoder-related
        # member variables populated.
        if not (from_decoder_prompt
                or is_valid_encoder_decoder_llm_inputs(inputs)):
            raise ValueError("Cannot extract encoder input prompt from "
                             f"invalid input {inputs}; did you forget the "
                             "encoder input prompt fields?")

        self.data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None

    @property
    def n_blocks(self) -> int:
        """Number of `block_size` blocks needed to hold all tokens
        (ceiling division)."""
        return (self.get_len() + self.block_size - 1) // self.block_size

    @property
    def prompt(self) -> Optional[str]:
        if self._prompt is not None:
            # Reuse precomputed prompt string
            return self._prompt

        # Select decoder or encoder input prompt str,
        # as appropriate
        prompt_key: str = ("prompt"
                           if self.from_decoder_prompt else "encoder_prompt")

        # Cache prompt
        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
        return self._prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        if self._prompt_token_ids is not None:
            # Reuse precomputed prompt token ids
            return self._prompt_token_ids

        # Select decoder or encoder input prompt
        # token ids, as appropriate
        prompt_token_ids_key: str = ("prompt_token_ids"
                                     if self.from_decoder_prompt else
                                     "encoder_prompt_token_ids")

        # Cache computed prompt token ids
        self._prompt_token_ids = cast(List[int],
                                      self.inputs.get(prompt_token_ids_key))
        return self._prompt_token_ids

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        return self.inputs.get("multi_modal_data") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    def get_output_text_to_return(self, buffer_length: int) -> str:
        """Return the output text, withholding the last `buffer_length`
        characters while the sequence is unfinished.

        A `buffer_length` of 0 returns the full text.
        """
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash of all tokens up to and including logical block
        `logical_idx`, combined with the LoRA id."""
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        """Number of tokens covered by blocks 0..`logical_idx` inclusive."""
        return (logical_idx + 1) * self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def append_token_id(self, token_id: int, logprobs: Dict[int,
                                                            Logprob]) -> None:
        """Append a generated token; `logprobs` must contain `token_id`."""
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_output_token_ids()

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(self,
                              length_penalty: float = 1.0,
                              seq_len: Optional[int] = None,
                              eos_token_id: Optional[int] = None) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from

        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence under a new sequence id."""
        new_seq = copy.deepcopy(self)
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, or
            the remaining prompt size for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={self.n_blocks})")
class SequenceGroupState(msgspec.Struct,
                         omit_defaults=True):  # type: ignore[call-arg]
    """Mutable state tied to one sequence group (multi-step decoding)."""

    # Total number of scheduler steps planned for multi-step decoding.
    num_steps: int = 1
    # Steps already taken; `remaining_steps` is num_steps minus this.
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        """Number of planned steps that have not been taken yet."""
        steps_left = self.num_steps - self.current_step
        return steps_left
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        embeddings: The embeddings vectors of the prompt of the sequence group
            for an embedding model.
        pooling_params: The pooling parameters used to generate the pooling
            for an embedding model.
        encoder_seq: Optional, the single encoder sequence. Should be None
            unless you are working with an encoder/decoder model.
        prompt_adapter_request: Prompt Adapter request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        arrival_time: float,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
        # NOTE: the caller's list is stored (not copied); `add`/`remove`
        # mutate it in place.
        self.seqs = seqs
        # Cached fast-path flag, kept in sync by `add` and `remove`.
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
        self.encoder_seq = encoder_seq

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt is distinct
        # from the decoder's.
        return (self.encoder_seq.prompt
                if self.encoder_seq is not None else None)

    @property
    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
        # There are either 0 or 1 encoder sequences
        # If one is present, its prompt token ids are
        # distinct from the decoder's.
        return (self.encoder_seq.prompt_token_ids
                if self.encoder_seq is not None else None)

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def init_multi_step(self, num_scheduler_steps: int) -> None:
        """Start a new multi-step run of `num_scheduler_steps` steps."""
        self.state.num_steps = num_scheduler_steps
        self.state.current_step = 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Return the time elapsed since `metrics.last_token_time` and
        advance `metrics.last_token_time` to `now`.

        No stage check is performed: the latency is computed even while
        the group is still in the prefill phase.
        """
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        # NOTE: in a case where a sequence_group is swapped and
        # recomputed, the time between iterations is counted
        # in TPOT, rather than recalculating TTFT (since from the
        # POV of the user, there is simply a long generation delay).
        if (self.metrics.first_token_time is None
                and self.seqs[0].get_output_len() == 1):
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request
        level timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            best_of = self.sampling_params.best_of
            assert isinstance(best_of, int)
            return best_of
        else:
            if self.sampling_params:
                best_of = self.sampling_params.best_of
                assert isinstance(best_of, int)
                if best_of > self.num_seqs():
                    # At prompt stage, the sequence group is not yet filled up
                    # and only have one sequence running. However, in the
                    # generation stage, we will have `best_of` sequences
                    # running
                    return best_of
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return all sequences, or only those with the given status."""
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.seqs[0].status == status else []

        return [seq for seq in self.seqs if seq.status == status]

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None

    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if not self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Total uncomputed tokens across the unfinished sequences."""
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        # Optimization. We don't need to call get_seqs if we don't need to
        # filter by states.
        if status is None:
            return len(self.seqs)

        if self.is_single_seq:
            return 1 if self.seqs[0].status == status else 0

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if not self.seqs[0].is_finished() else 0

        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0

        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with `seq_id`; raise ValueError if absent."""
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add `seq`; raise ValueError if its id is already present."""
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)
        self.is_single_seq = len(self.seqs) == 1

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with `seq_id`; raise ValueError if absent."""
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)
        self.is_single_seq = len(self.seqs) == 1

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.seqs)

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.seqs[0].is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs)})")
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Incremental update to a previously sent SequenceGroupMetadata.

    Once a group's full SequenceGroupMetadata has been sent, the scheduler
    sends only these deltas to keep the payload small. Field order matters
    because the struct is encoded as an array (``array_like=True``).
    """
    # Seq id -> per-sequence data delta.
    seq_data_delta: Dict[int, SequenceDataDelta]
    request_id: str
    # Seq id -> list of physical block numbers.
    block_tables: Dict[int, List[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=SequenceGroupState)
class SequenceGroupMetadata(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Metadata for a sequence group. Used to create `AttentionMetadata`.

    The struct is ``array_like``, so msgspec encodes it positionally: field
    order is part of the serialized format.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        do_sample: True if sampling is required. Sampling is not required when
            e.g., prefill is chunked, and the current iteration only computes
            query tokens for prefill, we don't need sampling.
        pooling_params: Optional pooling parameters (NOTE(review): presumably
            used by pooling/embedding models -- confirm).
        lora_request: LoRA request.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        encoder_seq_data: Optional sequence data for encoder prompt
            (SequenceGroup.encoder_seq). Should be None unless you are working
            with an encoder/decoder model.
        cross_block_table: Optional cross-attention block table associated
            with the encoder prompt (SequenceGroup.encoder_seq). Should be None
            unless you are working with an encoder/decoder model.
        prompt_adapter_request: Prompt Adapter request.
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required. When left as None and
            ``seq_data`` is set, ``__post_init__`` fills it in: the length of
            the first sequence for prompts, otherwise 1.
        num_speculative_tokens: The number of speculative tokens adopted in
            this request. None means speculative decoding is not used; zero
            means it is disabled for some reason.
    """
    request_id: str
    is_prompt: bool
    seq_data: Dict[int, SequenceData]
    sampling_params: SamplingParams
    block_tables: Dict[int, List[int]]
    do_sample: bool = True
    pooling_params: Optional[PoolingParams] = None
    lora_request: Optional[LoRARequest] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    multi_modal_data: Optional[Any] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
    token_chunk_size: Optional[int] = None

    ### Stateful fields that are lazily defined. ###
    # The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
    # Zero means speculative decoding is disabled for some reasons.
    # TODO: We should maintain this states out of the sequence group.
    num_speculative_tokens: Optional[int] = None

    def __post_init__(self):
        # Derive a default chunk size: a prompt processes all of its tokens at
        # once, while a decode step processes a single token.
        if self.seq_data is not None and self.token_chunk_size is None:
            if self.is_prompt:
                self.token_chunk_size = next(iter(
                    self.seq_data.values())).get_len()
            else:
                self.token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer LoRA ID, or 0 when no LoRA request is attached."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    @property
    def prompt_adapter_id(self) -> int:
        """Prompt adapter ID, or 0 when no prompt adapter is attached."""
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)

    @property
    def prompt_adapter_num_virtual_tokens(self) -> int:
        """Number of prompt-adapter virtual tokens, or 0 when none."""
        return (self.prompt_adapter_request.prompt_adapter_num_virtual_tokens
                if self.prompt_adapter_request else 0)

    def apply_delta(self,
                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
        """Update this metadata in place from a scheduler delta.

        Per-sequence deltas are applied to the matching ``seq_data`` entries,
        and ``block_tables``, ``token_chunk_size``, ``do_sample`` and
        ``is_prompt`` are overwritten with the delta's values.

        Raises:
            AssertionError: If the delta belongs to a different request.
        """
        # Validate before mutating anything, so a mismatched delta cannot
        # leave this object partially updated.
        assert self.request_id == sequence_group_metadata_delta.request_id
        seq_deltas = sequence_group_metadata_delta.seq_data_delta
        for seq_id, seq_delta in seq_deltas.items():
            self.seq_data[seq_id].apply_delta(seq_delta)
        self.block_tables = sequence_group_metadata_delta.block_tables
        self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
        self.do_sample = sequence_group_metadata_delta.do_sample
        self.is_prompt = sequence_group_metadata_delta.is_prompt

    def finish_step(self) -> None:
        """Advance ``state.current_step`` by one.

        Asserts that a state is attached and that the step counter is still
        below ``state.num_steps``.
        """
        assert self.state is not None
        assert self.state.current_step < self.state.num_steps
        self.state.current_step += 1
class SequenceOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, Logprob]

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare parent ID, output token and logprobs field by field."""
        # Returning NotImplemented (instead of raising) lets Python fall back
        # to the reflected comparison / identity check, so comparing against
        # an unrelated type (e.g. None) yields False rather than crashing.
        if not isinstance(other, SequenceOutput):
            return NotImplemented
        return (self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
class SequenceGroupOutput(ABC):
    """Abstract interface for model outputs tied to one sequence group.

    Concrete subclasses must implement both ``__repr__`` and ``__eq__``.
    """

    @abstractmethod
    def __repr__(self) -> str:
        ...

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        ...
class CompletionSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The model output associated with a completion sequence group."""
    # NOTE: ``__metaclass__`` is a Python 2 idiom that Python 3 ignores, so
    # this class is NOT actually a subclass of SequenceGroupOutput. It is kept
    # only as a plain class attribute for backward compatibility.
    __metaclass__ = SequenceGroupOutput
    samples: List[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]

    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Equal when both the samples and the prompt logprobs match."""
        # Return NotImplemented for foreign types so ``==`` falls back to
        # Python's default handling instead of raising.
        if not isinstance(other, CompletionSequenceGroupOutput):
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with an embedding sequence group."""
    # NOTE: ``__metaclass__`` is a Python 2 idiom that Python 3 ignores, so
    # this class is NOT actually a subclass of SequenceGroupOutput. It is kept
    # only as a plain class attribute for backward compatibility.
    __metaclass__ = SequenceGroupOutput
    # NOTE(review): annotated as List[int]; if embeddings carry float values,
    # msgspec decoding against this annotation may reject them -- confirm.
    embeddings: List[int]

    def __repr__(self) -> str:
        return (f"EmbeddingSequenceGroupOutput("
                f"embeddings_shape={len(self.embeddings)})")

    def __eq__(self, other: object) -> bool:
        """Equal when the embedding lists compare equal."""
        # Return NotImplemented for foreign types so ``==`` falls back to
        # Python's default handling instead of raising.
        if not isinstance(other, EmbeddingSequenceGroupOutput):
            return NotImplemented
        return self.embeddings == other.embeddings
class IntermediateTensors(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return one named tensor, or a sliced copy of the whole bundle.

        A ``str`` key returns ``self.tensors[key]``. A ``slice`` key returns a
        new instance in which every tensor is indexed with that slice.

        Raises:
            KeyError: If a ``str`` key is not present.
            TypeError: For any other key type.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously fell through and returned None silently.
        raise TypeError(
            f"IntermediateTensors keys must be str or slice, got {type(key)}")

    def __setitem__(self, key: str, value):
        self.tensors[key] = value

    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
        """Equal when both hold the same names mapped to equal tensors."""
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # Tensor ``==`` is elementwise and has no single truth value, so
        # compare each pair with torch.equal (same shape and same values).
        return all(
            torch.equal(value, other.tensors[name])
            for name, value in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

    Behaves like a list over ``outputs`` (indexing, assignment, ``len``) and
    additionally carries optional device/CPU tensors and metrics.
    """

    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    def __getitem__(self, idx: int):
        selected = self.outputs[idx]
        return selected

    def __setitem__(self, idx: int, value):
        target = self.outputs
        target[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only the per-group outputs take part in equality.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensor fields by shape instead of by value to keep the
        output short.
        """

        def _shape_or_none(tensor):
            return "None" if tensor is None else tensor.shape

        probs_desc = _shape_or_none(self.sampled_token_probs)
        ids_desc = _shape_or_none(self.sampled_token_ids)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_desc}, "
            f"sampled_token_ids={ids_desc}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the embedding model.

    Acts like a list over ``outputs`` for indexing, assignment and ``len``.
    """
    outputs: List[EmbeddingSequenceGroupOutput]

    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    def __getitem__(self, idx: int):
        selected = self.outputs[idx]
        return selected

    def __setitem__(self, idx: int, value):
        target = self.outputs
        target[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Equality considers only the per-group outputs.
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs
def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Flatten the sequence IDs of every group into one list.

    IDs appear in group order and, within a group, in the iteration order of
    that group's ``seq_data`` mapping.
    """
    all_ids: List[int] = []
    for group in seq_group_metadata_list:
        # Iterating a dict yields its keys, i.e. the sequence IDs.
        all_ids.extend(group.seq_data)
    return all_ids
def get_all_seq_ids_and_request_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
    """Collect sequence IDs and group them by request ID.

    Args:
        seq_group_metadata_list: The sequence groups to scan.

    Returns:
        A tuple ``(seq_ids, request_id_seq_ids_mapping)``:

        - ``seq_ids``: every sequence ID, in group order and, within a group,
          in the iteration order of its ``seq_data``.
        - ``request_id_seq_ids_mapping``: request ID -> set of that request's
          sequence IDs. This is a ``defaultdict(set)``, so indexing an unknown
          request ID returns (and inserts) an empty set.
    """
    seq_ids: List[int] = []
    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for sg in seq_group_metadata_list:
        for seq_id in sg.seq_data:
            seq_ids.append(seq_id)
            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
    return seq_ids, request_id_seq_ids_mapping
class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""
    # Scorer hidden states. For prefill step, it is used for hidden states of
    # all tokens, whereas for decode step, it is used for last accepted tokens.
    hidden_states: torch.Tensor
    # The sequence group metadata list. Only needed for decode step.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scorer hidden states of the 2nd last token proposed by the proposer (
    # irrespective of whether it was accepted or not). Only used for cases when
    # last proposed token is accepted (i.e., in case of bonus tokens). For the
    # case of no bonus tokens, these are ignored.
    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
        # Derive the per-row sequence IDs from the metadata when provided.
        if self.seq_group_metadata_list is not None:
            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        """Sequence IDs, one per tracked row of ``hidden_states``."""
        return self._seq_ids

    def update(self,
               hidden_states: torch.Tensor,
               seq_group_metadata_list: List[SequenceGroupMetadata],
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps.

        Appends the new rows and their sequence IDs. If second-last-token
        states are being tracked, they are extended as well (with zeros when
        none are supplied for the new rows).
        """
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

        if self.second_last_token_hidden_states is not None:
            # Adding dummy hidden_states to this to maintain same shape
            self.second_last_token_hidden_states = torch.cat([
                self.second_last_token_hidden_states,
                torch.zeros_like(hidden_states)
                if second_last_token_hidden_states is None else
                second_last_token_hidden_states
            ])

    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Prune to provided list of sequence ids. Only used for decode steps.

        Rows are reordered/selected to match the order of the sequence IDs in
        ``seq_group_metadata_list``.

        Raises:
            ValueError: If a requested sequence ID is not currently tracked.
        """
        # Currently this prunes all seq_ids not present in
        # seq_group_metadata_list which might cause problems where a sequence
        # may be "paused" then "resumed" later. This should only prune sequences
        # which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
            if self.second_last_token_hidden_states is not None:
                self.second_last_token_hidden_states = self\
                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

    def expand_with_bonus_tokens(
            self, seq_with_bonus_token_in_last_step: set) -> None:
        """Expand hidden states for sequences with bonus tokens. This is in
        alignment with `MultiStepWorker._expand_execute_model_request`.

        For each sequence that received a bonus token in the last step, its
        second-last-token hidden state is inserted just before its regular
        hidden state. No-op if second-last-token states are not tracked or
        the set is empty.
        """
        if (self.second_last_token_hidden_states is None
                or not seq_with_bonus_token_in_last_step):
            return

        # Rows of second_last_token_hidden_states start at this offset in the
        # concatenation below (assumes hidden_states has one row per entry of
        # _seq_ids).
        offset = len(self._seq_ids)
        index = []
        # enumerate gives each position directly; the previous
        # ``self._seq_ids.index(seq_id)`` rescanned the list on every
        # iteration (O(n^2)).
        for i, seq_id in enumerate(self._seq_ids):
            if seq_id in seq_with_bonus_token_in_last_step:
                index.append(i + offset)
            index.append(i)
        self.hidden_states = torch.cat(
            [self.hidden_states, self.second_last_token_hidden_states])[index]
class ExecuteModelRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta]]
    # Blocks to swap in. List of CPU -> GPU block number.
    blocks_to_swap_in: List[Tuple[int,
                                  int]] = msgspec.field(default_factory=list)
    # Blocks to swap out. List of GPU -> CPU block number.
    blocks_to_swap_out: List[Tuple[int,
                                   int]] = msgspec.field(default_factory=list)
    # Blocks to copy. Source to dest block.
    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
    # Virtual engine ID for pipeline parallel.
    virtual_engine: int = 0
    # The number of slots for lookahead decoding.
    num_lookahead_slots: int = 0
    # The number of requests in the running queue.
    running_queue_size: int = 0
    # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
    # The number of forward steps to run.
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None
    # Async postprocessor callback.
    output_proc_callback_fn: Optional[Callable] = None

    def _first_seq_group_state(self) -> SequenceGroupState:
        """Return the state of the first sequence group in the batch.

        Raises:
            AssertionError: If the batch is empty or the first group has no
                state attached.
        """
        # TODO: make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        state = self.seq_group_metadata_list[0].state
        assert state is not None
        return state

    @property
    def is_first_multi_step(self) -> bool:
        """True when the first group's ``current_step`` is 0."""
        return self._first_seq_group_state().current_step == 0

    @property
    def is_last_step(self) -> bool:
        """True when the first group has exactly one step remaining."""
        return self._first_seq_group_state().remaining_steps == 1

    @property
    def current_step(self) -> int:
        """The first group's ``current_step``."""
        return self._first_seq_group_state().current_step

    def clone(
        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                  SequenceGroupMetadataDelta]]
    ) -> "ExecuteModelRequest":
        """Clone the request with a new sequence group metadata list.

        List-valued fields are shallow-copied and ``last_sampled_token_ids``
        is cloned, so mutating them on the clone does not affect this
        request. ``previous_hidden_states`` and ``output_proc_callback_fn``
        are shared by reference.
        """
        return ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
            blocks_to_copy=self.blocks_to_copy.copy(),
            virtual_engine=self.virtual_engine,
            num_lookahead_slots=self.num_lookahead_slots,
            running_queue_size=self.running_queue_size,
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            # Copied like the block lists above; previously the list object
            # was shared between the original and the clone.
            finished_requests_ids=self.finished_requests_ids.copy(),
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None,
            output_proc_callback_fn=self.output_proc_callback_fn)
|