12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754 |
- """Sequence and its related classes."""
- import copy
- import enum
- from dataclasses import dataclass
- from typing import Dict, List, Optional, Union, TYPE_CHECKING
- from aphrodite.common.block import LogicalTokenBlock
- from aphrodite.common.sampling_params import SamplingParams, SamplingType
- from aphrodite.lora.request import LoRARequest
- if TYPE_CHECKING:
- import torch
- from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
@dataclass
class Logprob:
    """Log-probability record for a single token.

    Holds what an OpenAI-compatible response needs in order to report
    logprobs and token ranks.

    Attributes:
        logprob: Log probability of the token.
        rank: Rank of the token within the vocabulary (>= 1), or None when
            it was not provided.
        decoded_token: Decoded string of the token, or None when it was not
            provided.
    """
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# Per-position logprobs for the prompt; individual entries may be None.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# Per-position logprobs for the generated tokens.
SampleLogprobs = List[Dict[int, Logprob]]
class SequenceStatus(enum.Enum):
    """Lifecycle states of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        """Return True when `status` is one of the FINISHED_* states."""
        finished_states = (
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        )
        return status in finished_states

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
        """Map a status to its finish-reason string.

        Returns:
            "stop", "length" or "abort" for the finished states, and None
            for every other status.
        """
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        if status == SequenceStatus.FINISHED_ABORTED:
            return "abort"
        if status == SequenceStatus.FINISHED_IGNORED:
            # Ignored sequences are those whose prompt is longer than the
            # model's length cap, so they report "length" just like the
            # OpenAI API does.
            return "length"
        return None
class SequenceStage(enum.Enum):
    """Processing phase of a sequence: prefill or decode."""
    PREFILL = enum.auto()
    DECODE = enum.auto()
@dataclass
class RequestMetrics:
    """Timing information collected for one request.

    Args:
        arrival_time: When the request arrived.
        last_token_time: Time of the most recent token.
        first_scheduled_time: When the request was scheduled for the first
            time, or None if that has not happened yet.
        first_token_time: When the first token was generated, or None if
            that has not happened yet.
        time_in_queue: How long the request spent in the queue, or None if
            it is not known yet.
        finished_time: When the request finished; defaults to None.
    """
    arrival_time: float
    last_token_time: float
    first_scheduled_time: Optional[float]
    first_token_time: Optional[float]
    time_in_queue: Optional[float]
    finished_time: Optional[float] = None
class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output. Set to an empty list if
            None.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
        inner_id: The ID to identify the sequences in the same group.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
        self._num_computed_tokens = 0
        self._stage: SequenceStage = SequenceStage.PREFILL
        # The ID to identify the sequences in the same group
        self.inner_id = 0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        """Append one output token and add its logprob to the running sum."""
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

    def get_len(self) -> int:
        """Return the total number of tokens (prompt + output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_prompt_len(self) -> int:
        """Return the number of prompt tokens."""
        return len(self.prompt_token_ids)

    def get_output_len(self) -> int:
        """Return the number of output tokens."""
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list holding the prompt tokens followed by the
        output tokens."""
        return self.prompt_token_ids + self.output_token_ids

    def get_num_computed_tokens(self) -> int:
        """Return the number of tokens that are already computed."""
        return self._num_computed_tokens

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        self._num_computed_tokens += num_new_computed_tokens
        assert self._num_computed_tokens <= self.get_len(), (
            self._num_computed_tokens, self.get_len())
        # If all tokens are computed, it means it is in decoding phase.
        if self.get_num_uncomputed_tokens() == 0:
            self._stage = SequenceStage.DECODE

    def reset_state_for_recompute(self) -> None:
        """Reset the number of computed tokens from this sequence. It is
        supposed to be called when a sequence needs to be started from
        the beginning again (e.g., sequence is preempted).
        """
        self._num_computed_tokens = 0
        self._stage = SequenceStage.PREFILL

    def get_num_uncomputed_tokens(self) -> int:
        """Return the number of prefill tokens that are not computed."""
        # we use `get_len()` which includes prompt_len + output_len instead
        # of prompt_len here. This is because during recompute we need to
        # prefill for both prompt and output.
        return self.get_len() - self.get_num_computed_tokens()

    def get_last_token_id(self) -> int:
        """Return the last output token, or the last prompt token when no
        output has been generated yet."""
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
        """Return the prompt token list itself (not a copy)."""
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
        """Return the output token list itself (not a copy)."""
        return self.output_token_ids

    @property
    def stage(self) -> SequenceStage:
        """The current stage (prefill or decode) of this sequence."""
        return self._stage

    def __repr__(self) -> str:
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids}, "
                f"cumulative_logprob={self.cumulative_logprob})")
class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
        eos_token_id: The end-of-sequence token ID, if any.
        lora_request: LoRA request.
    """

    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
        eos_token_id: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.lora_request = lora_request

        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: SampleLogprobs = []
        self.output_text = ""

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids)
        self.status = SequenceStatus.WAITING
        self.stop_reason: Union[int, str, None] = None

        # Used for incremental detokenization
        self.prefix_offset = 0
        self.read_offset = 0
        # Input + output tokens
        self.tokens: Optional[List[str]] = None
        self.persistent_data = {}
        # Back-reference to the owning group; None until the sequence is
        # attached to one.
        self.seq_group: Optional[SequenceGroup] = None
        # The id to identify the sequences in same Group
        self.inner_id = 0

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_output_text_to_return(self, buffer_length: int):
        """Return the output text, holding back the last `buffer_length`
        characters while the sequence is still running."""
        # We return the full output text if the sequence is finished.
        truncate = buffer_length and not self.is_finished()
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        """Hash the token prefix ending at block `logical_idx` together with
        the LoRA id."""
        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(
            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        """Return the number of tokens covered by blocks 0..`logical_idx`."""
        return logical_idx * self.block_size + self.block_size

    def reset_state_for_recompute(self):
        """Reset the sequence states for recomputation."""
        self.data.reset_state_for_recompute()

    def _append_logical_block(self) -> None:
        """Append a new, empty logical block to this sequence."""
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        """Distribute `token_ids` over the logical blocks, opening new
        blocks whenever the last one is full."""
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self._append_logical_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots

    def append_token_id(
        self,
        token_id: int,
        logprobs: Dict[int, Logprob],
    ) -> None:
        """Append a generated token and its logprobs to this sequence.

        Args:
            token_id: The sampled token ID; must be a key of `logprobs`.
            logprobs: Token id -> Logprob for the sampled position.
        """
        assert token_id in logprobs
        seq_group = self.seq_group

        # allocate block for root seq if sampling type is RANDOM_SEED and
        # generate multiple sequence per prompt
        # A sequence that is not attached to a group has no root to
        # redirect to, so its own blocks are used.
        use_root_blocks = (
            seq_group is not None
            and self.seq_id != seq_group.root_seq_id
            and seq_group.sampling_params.sampling_type in (
                SamplingType.RANDOM, SamplingType.RANDOM_SEED))
        if use_root_blocks:
            buffer_seq = seq_group.find(seq_group.root_seq_id)
        else:
            buffer_seq = self

        buffer_seq._append_tokens_to_blocks([token_id])
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

    def get_len(self) -> int:
        return self.data.get_len()

    def get_prompt_len(self) -> int:
        return self.data.get_prompt_len()

    def get_output_len(self) -> int:
        return self.data.get_output_len()

    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
        return self.data.get_last_token_id()

    def get_output_token_ids(self) -> List[int]:
        return self.data.output_token_ids

    def get_cumulative_logprob(self) -> float:
        return self.data.cumulative_logprob

    def get_beam_search_score(
        self,
        length_penalty: float = 1.0,
        seq_len: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> float:
        """Calculate the beam search score with length penalty.

        Adapted from
        https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
        """
        if seq_len is None:
            seq_len = self.get_len()
            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
            if (eos_token_id is not None
                    and self.get_last_token_id() == eos_token_id):
                seq_len -= 1
        return self.get_cumulative_logprob() / (seq_len**length_penalty)

    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, new_seq_id: int) -> "Sequence":
        """Return a deep copy of this sequence with id `new_seq_id`.

        The copy shares the same `seq_group` object as the original.
        """
        seq_group = self.seq_group
        # Detach the group while copying so deepcopy does not recurse into
        # it; always re-attach, even if the copy fails.
        self.seq_group = None
        try:
            new_seq = copy.deepcopy(self)
        finally:
            self.seq_group = seq_group
        new_seq.seq_group = seq_group
        new_seq.seq_id = new_seq_id
        return new_seq

    def get_num_new_tokens(self) -> int:
        """Get the number of new tokens to be computed.

        Returns:
            The new number of tokens to be computed. I.e., 1 for decode, the
            number of not-yet-computed tokens for prefill.
        """
        if self.data.stage == SequenceStage.DECODE:
            return 1
        return self.data.get_num_uncomputed_tokens()

    def is_prefill(self) -> bool:
        return self.data.stage == SequenceStage.PREFILL

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group."""

    # torch.Generator used in seeded sampling. The annotation is a string
    # so that torch is not needed at runtime.
    generator: Optional["torch.Generator"] = None
class MultiModalData:
    """Container for the multi modal part of a request.

    Args:
        type: The kind of data being carried.
        data: The payload itself. Its required shape and meaning depend on
            the vision language config of the hosted model; see
            `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        """Supported kinds of multi modal data."""
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.data = data
        self.type = type
class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences. The first one becomes the root
            sequence of the group.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data for the request.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.metrics = RequestMetrics(
            arrival_time=arrival_time,
            last_token_time=arrival_time,
            first_scheduled_time=None,
            first_token_time=None,
            time_in_queue=None,
        )
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data
        # A group normally starts with a single sequence, but link every
        # sequence that was passed in so none is left without its
        # back-reference to the group.
        for seq in seqs:
            seq.seq_group = self
        self.root_seq_id = seqs[0].seq_id

    @property
    def prompt(self) -> str:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        return self.lora_request.lora_int_id if self.lora_request else 0

    def get_last_latency(self, now: float) -> float:
        """Gets last token latency for Request level timings.

        Also records `now` as the new last token time.
        """
        latency = now - self.metrics.last_token_time
        self.metrics.last_token_time = now
        return latency

    def maybe_set_first_token_time(self, time: float) -> None:
        """Sets the first token time for Request level timings."""
        if self.metrics.first_token_time is None:
            self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
        """Sets the first scheduled time and time in queue for Request level
        timings."""
        if self.metrics.first_scheduled_time is None:
            self.metrics.first_scheduled_time = time
            self.metrics.time_in_queue = time - self.metrics.arrival_time

    def set_finished_time(self, time: Optional[float]) -> None:
        """Sets the finished time for Request level timings."""
        self.metrics.finished_time = time

    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        best_of = self.sampling_params.best_of
        if self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return best_of
        if best_of > self.num_seqs():
            # At prompt stage, the sequence group is not yet filled up
            # and only have one sequence running. However, in the
            # generation stage, we will have `best_of` sequences running.
            return best_of
        # At sampling stages, return the number of actual sequences
        # that are not finished yet.
        return self.num_unfinished_seqs()

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        """Return the sequences of the group, optionally only those whose
        status equals `status`."""
        if status is None:
            return list(self.seqs_dict.values())
        return [seq for seq in self.seqs_dict.values() if seq.status == status]

    def get_unfinished_seqs(self) -> List[Sequence]:
        return [
            seq for seq in self.seqs_dict.values() if not seq.is_finished()
        ]

    def get_finished_seqs(self) -> List[Sequence]:
        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.get_unfinished_seqs():
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        """Return the uncomputed token count summed over all unfinished
        sequences."""
        return sum(seq.data.get_num_uncomputed_tokens()
                   for seq in self.get_unfinished_seqs())

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        """Return the sequence with id `seq_id`.

        Raises:
            ValueError: If the group has no such sequence.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        """Add `seq` to the group and link it back to the group.

        Raises:
            ValueError: If a sequence with the same id is already present.
        """
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        seq.seq_group = self
        # The inner id is the number of sequences present before this one.
        seq.inner_id = len(self.seqs_dict)
        seq.data.inner_id = seq.inner_id
        self.seqs_dict[seq.seq_id] = seq

    def remove(self, seq_id: int) -> None:
        """Remove the sequence with id `seq_id` from the group.

        Raises:
            ValueError: If the group has no such sequence.
        """
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        del self.seqs_dict[seq_id]

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
        # Every sequences should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def get_root(self) -> Sequence:
        """Return the root sequence of the group."""
        return self.find(self.root_seq_id)

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
                f"sampling_params={self.sampling_params}, "
                f"num_seqs={len(self.seqs_dict)})")
class SequenceGroupMetadata:
    """Per-step metadata of a sequence group, used to build
    `AttentionMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: Mapping of seq id -> sequence data.
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: Mapping of seq id -> list of physical block numbers.
        persistent_data: The persistent data of the sequence group.
        token_chunk_size: How many tokens to process per sequence. When
            None, it is derived: the length of the first sequence at prompt
            stage, 1 otherwise.
        lora_request: LoRA request.
        computed_block_nums: Optional list of block numbers; stored as
            given.
        state: Internal state tied to this sequence group. A fresh
            `SequenceGroupState` is created when None.
        multi_modal_data: Multi modal data for the request.
        root_seq_id: Optional ID of the group's root sequence; stored as
            given.
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        persistent_data: Dict[int, dict],
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
        root_seq_id: Optional[int] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.persistent_data = persistent_data
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        if state is None:
            state = SequenceGroupState()
        self.state = state
        self.multi_modal_data = multi_modal_data
        self._token_chunk_size = token_chunk_size
        self.root_seq_id = root_seq_id

        if token_chunk_size is None:
            # No explicit chunk size: a prompt step covers the whole first
            # sequence, any other step covers a single token.
            if is_prompt:
                first_seq_data = list(seq_data.values())[0]
                self._token_chunk_size = first_seq_data.get_len()
            else:
                self._token_chunk_size = 1

    @property
    def lora_int_id(self) -> int:
        """Integer id of the LoRA request, or 0 when there is none."""
        if self.lora_request:
            return self.lora_request.lora_int_id
        return 0

    @property
    def token_chunk_size(self) -> int:
        """Return the number of tokens to be processed (chunk size)."""
        return self._token_chunk_size
class SequenceOutput:
    """The model output associated with a sequence.

    Args:
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
        persistent_data: The persistent data of the sequence.
    """

    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, Logprob],
        persistent_data: dict,
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs
        self.persistent_data = persistent_data

    def __repr__(self) -> str:
        return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}, "
                f"logprobs={self.logprobs}, "
                f"persistent_data={self.persistent_data})")

    def __eq__(self, other: object) -> bool:
        """Compare parent id, output token and logprobs.

        `persistent_data` does not take part in the comparison.
        """
        if not isinstance(other, SequenceOutput):
            # Let Python try the reflected comparison and fall back to
            # identity instead of raising on `==` with a foreign type.
            return NotImplemented
        equal = (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token)
        log_probs_equal = other.logprobs == self.logprobs
        return equal and log_probs_equal
class SequenceGroupOutput:
    """The model output associated with a sequence group.

    Args:
        samples: The outputs sampled for the sequences of the group.
        prompt_logprobs: The logprobs of the prompt, or None.
    """

    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
        return (f"SequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
        """Compare samples and prompt logprobs."""
        if not isinstance(other, SequenceGroupOutput):
            # Let Python try the reflected comparison and fall back to
            # identity instead of raising on `==` with a foreign type.
            return NotImplemented
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)
@dataclass
class SamplerOutput:
    """Sampler result: one SequenceGroupOutput per sequence group, each
    carrying the candidate next tokens for that group.

    The class supports indexing, item assignment and `len()` so it can be
    used like a list of its `outputs`, and it additionally has optional
    fields for device tensors.
    """

    outputs: List[SequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional["torch.Tensor"] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

    def __getitem__(self, idx: int):
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        # Only `outputs` is compared; the tensor fields and the metrics are
        # left out.
        return (isinstance(other, self.__class__)
                and self.outputs == other.outputs)

    def __repr__(self) -> str:
        """Render tensors by their shape rather than their values to keep
        the output short.
        """
        probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        sampled_token_probs_repr = "None" if probs is None else probs.shape
        sampled_token_ids_repr = ("None"
                                  if token_ids is None else token_ids.shape)
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={sampled_token_probs_repr}, "
            f"sampled_token_ids={sampled_token_ids_repr}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|