@@ -7,6 +7,7 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, cast
 
+import numpy
 import torch
 
 from aphrodite.common.pooling_params import PoolingParams
@@ -474,6 +475,19 @@ class Sequence:
                 f"num_blocks={self.n_blocks}, ")
 
 
+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # for multi-step decoding
+    num_steps: int = 1
+    current_step: int = 0
+
+    @property
+    def remaining_steps(self) -> int:
+        return self.num_steps - self.current_step
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
 
@@ -516,6 +530,7 @@ class SequenceGroup:
                                       time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -569,6 +584,10 @@ class SequenceGroup:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
             if self.prompt_adapter_request else 0
 
+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -735,6 +754,7 @@ class SequenceGroupMetadata:
         token_chunk_size: The number of tokens to be processed (per sequence).
             None if chunking is not required.
         lora_request: LoRA request.
+        state: Internal state tied to this sequence group.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
         multi_modal_data: Multi modal data.
@@ -762,6 +782,7 @@ class SequenceGroupMetadata:
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -777,6 +798,7 @@ class SequenceGroupMetadata:
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -815,6 +837,10 @@ class SequenceGroupMetadata:
         assert self._token_chunk_size is not None
         return self._token_chunk_size
 
+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+
 
 class SequenceOutput:
     """The model output associated with a sequence.
@@ -952,6 +978,7 @@ class SamplerOutput:
 
     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None
 
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -1086,6 +1113,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO: make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO: make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO: make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step
 
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1102,4 +1156,5 @@ class ExecuteModelRequest:
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
             finished_requests_ids=self.finished_requests_ids,
-        )
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
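
For orientation, a minimal, self-contained sketch (not part of the patch) of the step-counter lifecycle the hunks above introduce. The names _State and run_multi_step are local stand-ins rather than aphrodite APIs; they mirror SequenceGroupState, SequenceGroup.init_multi_step(), SequenceGroupMetadata.finish_step(), and the is_first_multi_step / is_last_step properties added to ExecuteModelRequest.

# Self-contained sketch: mirrors the multi-step counters added by this patch.
# _State and run_multi_step are stand-ins, not aphrodite imports.
from dataclasses import dataclass


@dataclass
class _State:
    # Mirrors SequenceGroupState: total scheduler steps for this group and
    # how many have completed so far.
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        return self.num_steps - self.current_step


def run_multi_step(state: _State, num_scheduler_steps: int) -> None:
    # init_multi_step: arm the counters before the first forward pass.
    state.num_steps = num_scheduler_steps
    state.current_step = 0
    while True:
        is_first = state.current_step == 0    # cf. is_first_multi_step
        is_last = state.remaining_steps == 1  # cf. is_last_step
        print(f"step {state.current_step}: first={is_first}, last={is_last}")
        # ... one model forward pass would run here ...
        assert state.current_step < state.num_steps  # finish_step precondition
        state.current_step += 1                      # cf. finish_step
        if is_last:
            break


state = _State()
run_multi_step(state, num_scheduler_steps=4)
assert state.current_step == state.num_steps == 4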