@@ -1,21 +1,30 @@
 """Sequence and its related classes."""
 import copy
 import enum
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, TYPE_CHECKING
+import math
+from abc import ABC, abstractmethod
+from array import array
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

-from aphrodite.common.block import LogicalTokenBlock
+import torch
+
+from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.lora.request import LoRARequest
+from aphrodite.prompt_adapter.request import PromptAdapterRequest

 if TYPE_CHECKING:
-    import torch
+    from aphrodite.inputs import LLMInputs
+    from aphrodite.multimodal import MultiModalDataDict
     from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics


 @dataclass
 class Logprob:
     """Infos for supporting OpenAI compatible logprobs and token ranks.
+
     Attributes:
         logprob: The logprob of chosen token
         rank: The vocab rank of chosen token (>=1)
@@ -26,29 +35,28 @@ class Logprob:
     decoded_token: Optional[str] = None


+# {token_id -> logprob} per each sequence group. None if the corresponding
+# sequence group doesn't require prompt logprob.
 PromptLogprobs = List[Optional[Dict[int, Logprob]]]
+# {token_id -> logprob} for each sequence group.
 SampleLogprobs = List[Dict[int, Logprob]]


-class SequenceStatus(enum.Enum):
+class SequenceStatus(enum.IntEnum):
     """Status of a sequence."""
-
-    WAITING = enum.auto()
-    RUNNING = enum.auto()
-    SWAPPED = enum.auto()
-    FINISHED_STOPPED = enum.auto()
-    FINISHED_LENGTH_CAPPED = enum.auto()
-    FINISHED_ABORTED = enum.auto()
-    FINISHED_IGNORED = enum.auto()
+    WAITING = 0
+    RUNNING = 1
+    SWAPPED = 2
+    # Note: anything after SWAPPED (2) will be considered
+    # as a finished status.
+    FINISHED_STOPPED = 3
+    FINISHED_LENGTH_CAPPED = 4
+    FINISHED_ABORTED = 5
+    FINISHED_IGNORED = 6

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
-        return status in [
-            SequenceStatus.FINISHED_STOPPED,
-            SequenceStatus.FINISHED_LENGTH_CAPPED,
-            SequenceStatus.FINISHED_ABORTED,
-            SequenceStatus.FINISHED_IGNORED,
-        ]
+        return status > SequenceStatus.SWAPPED

     @staticmethod
     def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
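
Illustrative aside (not part of the patch): because SequenceStatus is now an IntEnum whose finished members get the largest values, is_finished reduces to a single integer comparison. A minimal self-contained sketch of the property being relied on:

    from enum import IntEnum

    class Status(IntEnum):
        WAITING = 0
        RUNNING = 1
        SWAPPED = 2
        FINISHED_STOPPED = 3

    # Everything ordered after SWAPPED counts as finished.
    assert Status.FINISHED_STOPPED > Status.SWAPPED
    assert not (Status.RUNNING > Status.SWAPPED)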
@@ -77,14 +85,13 @@ class SequenceStage(enum.Enum):
 class RequestMetrics:
     """Metrics associated with a request.

-    Args:
+    Attributes:
         arrival_time: The time when the request arrived.
         first_scheduled_time: The time when the request was first scheduled.
         first_token_time: The time when the first token was generated.
         time_in_queue: The time the request spent in the queue.
         finished_time: The time when the request was finished.
     """
-
     arrival_time: float
     last_token_time: float
     first_scheduled_time: Optional[float]
@@ -112,31 +119,76 @@ class SequenceData:
         prompt_token_ids: List[int],
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
-        if output_token_ids is None:
-            output_token_ids = []
+        self._prompt_token_ids = array('l', prompt_token_ids)
+        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
+        self._output_token_ids = array(
+            'l', output_token_ids if output_token_ids is not None else [])

-        self.prompt_token_ids = prompt_token_ids
-        self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL

+        self._update_cached_all_tokens()
+
+    def _update_cached_all_tokens(self):
+        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
+                                                     self._output_token_ids)
+
+    @property
+    def prompt_token_ids(self) -> Tuple[int, ...]:
+        return self._prompt_token_ids_tuple
+
+    @prompt_token_ids.setter
+    def prompt_token_ids(self, new_prompt_token_ids) -> None:
+        self._prompt_token_ids = array('l', new_prompt_token_ids)
+        self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
+        self._update_cached_all_tokens()
+
+    @property
+    def prompt_token_ids_array(self) -> array:
+        return self._prompt_token_ids
+
+    @property
+    def output_token_ids(self) -> Tuple[int, ...]:
+        return tuple(self._output_token_ids)
+
+    @output_token_ids.setter
+    def output_token_ids(self, new_output_token_ids) -> None:
+        self._output_token_ids = array('l', new_output_token_ids)
+        self._update_cached_all_tokens()
+
+    @property
+    def output_token_ids_array(self) -> array:
+        return self._output_token_ids
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
-        self.output_token_ids.append(token_id)
+        self._output_token_ids.append(token_id)
+        self._cached_all_token_ids.append(token_id)
         self.cumulative_logprob += logprob

     def get_len(self) -> int:
-        return len(self.output_token_ids) + len(self.prompt_token_ids)
+        return len(self._output_token_ids) + len(self._prompt_token_ids)

     def get_prompt_len(self) -> int:
-        return len(self.prompt_token_ids)
+        return len(self._prompt_token_ids)

     def get_output_len(self) -> int:
-        return len(self.output_token_ids)
+        return len(self._output_token_ids)

     def get_token_ids(self) -> List[int]:
-        return self.prompt_token_ids + self.output_token_ids
+        return self._cached_all_token_ids
+
+    def get_prefix_token_ids(
+            self, num_tokens: int
+    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
+        """Get prefix tokens, and make the return value hashable"""
+        prompt_length = self.get_prompt_len()
+        if num_tokens > prompt_length:
+            return (self._prompt_token_ids_tuple,
+                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
+        else:
+            return (self._prompt_token_ids_tuple[:num_tokens], None)

     def get_num_computed_tokens(self) -> int:
         """Return the number of prefill tokens that are already computed."""
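
Illustrative aside (not part of the patch; assumes the module is importable as aphrodite.common.sequence): a sketch of the new SequenceData accessors. Prompt and output token ids are exposed as hashable tuples, while get_token_ids() returns the cached prompt+output list that append_token_id keeps up to date.

    from aphrodite.common.sequence import SequenceData

    data = SequenceData([1, 2, 3])
    data.append_token_id(7, logprob=-0.25)

    assert data.prompt_token_ids == (1, 2, 3)        # immutable tuple
    assert data.output_token_ids == (7,)
    assert data.get_token_ids() == [1, 2, 3, 7]      # cached prompt + output
    assert data.get_prefix_token_ids(2) == ((1, 2), None)
    assert data.get_prefix_token_ids(4) == ((1, 2, 3), (7,))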
@@ -160,21 +212,21 @@ class SequenceData:
         self._stage = SequenceStage.PREFILL

     def get_num_uncomputed_tokens(self) -> int:
-        """Return the number of prefil tokens that are not computed."""
+        """Return the number of prefill tokens that are not computed."""
         # we use `get_len()` which includes prompt_len + output_len instead
         # of prompt_len here. This is because during recompute we need to
         # prefill for both prompt and output.
         return self.get_len() - self.get_num_computed_tokens()

     def get_last_token_id(self) -> int:
-        if not self.output_token_ids:
-            return self.prompt_token_ids[-1]
-        return self.output_token_ids[-1]
+        if not self._output_token_ids:
+            return self._prompt_token_ids[-1]
+        return self._output_token_ids[-1]

-    def get_prompt_token_ids(self) -> int:
+    def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.prompt_token_ids

-    def get_output_token_ids(self) -> int:
+    def get_output_token_ids(self) -> Tuple[int, ...]:
         return self.output_token_ids

     @property
@@ -183,8 +235,8 @@ class SequenceData:

     def __repr__(self) -> str:
         return (f"SequenceData("
-                f"prompt_token_ids={self.prompt_token_ids}, "
-                f"output_token_ids={self.output_token_ids}, "
+                f"prompt_token_ids={self._prompt_token_ids}, "
+                f"output_token_ids={self._output_token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob})")


@@ -193,35 +245,33 @@ class Sequence:

     Args:
         seq_id: The ID of the sequence.
-        prompt: The prompt of the sequence.
-        prompt_token_ids: The token IDs of the prompt.
+        inputs: The inputs of the sequence.
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
         lora_request: LoRA request.
+        prompt_adapter_request: Prompt adapter request.
     """

     def __init__(
-        self,
-        seq_id: int,
-        prompt: str,
-        prompt_token_ids: List[int],
-        block_size: int,
-        eos_token_id: Optional[int] = None,
-        lora_request: Optional[LoRARequest] = None,
+            self,
+            seq_id: int,
+            inputs: "LLMInputs",
+            block_size: int,
+            eos_token_id: Optional[int] = None,
+            lora_request: Optional[LoRARequest] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> None:
         self.seq_id = seq_id
-        self.prompt = prompt
+        self.inputs = inputs
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
+        self.prompt_adapter_request = prompt_adapter_request

-        self.data = SequenceData(prompt_token_ids)
+        self.data = SequenceData(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

-        self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None

@@ -230,12 +280,32 @@ class Sequence:
         self.read_offset = 0
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
-        self.persistent_data = {}
+
+    @property
+    def n_blocks(self) -> int:
+        return math.ceil(self.get_len() / self.block_size)
+
+    @property
+    def prompt(self) -> Optional[str]:
+        return self.inputs.get("prompt")
+
+    @property
+    def prompt_token_ids(self) -> List[int]:
+        return self.inputs["prompt_token_ids"]
+
+    @property
+    def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
+        return self.inputs.get("multi_modal_data")

     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+                        if self.prompt_adapter_request else 0
+
     def get_output_text_to_return(self, buffer_length: int):
         # We return the full output text if the sequence is finished.
         truncate = buffer_length and not self.is_finished()
@@ -243,12 +313,14 @@ class Sequence:
                 self.output_text)

     def hash_of_block(self, logical_idx: int) -> int:
+        # TODO This can produce incorrect hash when block size > prompt size
+
         # Compute the number of tokens in the sequence
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(
-            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
+        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
+        return hash((hashed_tokens, self.lora_int_id))

     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
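
Illustrative aside (not part of the patch): a sketch of how the prefix-caching hash is formed after this change. The hashable prefix returned by get_prefix_token_ids is combined with the LoRA id, so two sequences with the same prompt prefix and adapter map to the same block hash.

    from aphrodite.common.sequence import SequenceData

    block_size = 4
    data = SequenceData([11, 12, 13, 14, 15])

    # Tokens covered by logical block 0.
    num_tokens = 0 * block_size + block_size
    hashed_tokens = data.get_prefix_token_ids(num_tokens)
    lora_int_id = 0  # no LoRA adapter in this sketch
    block_hash = hash((hashed_tokens, lora_int_id))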
@@ -257,36 +329,12 @@ class Sequence:
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()

-    def _append_logical_block(self) -> None:
-        block = LogicalTokenBlock(
-            block_number=len(self.logical_token_blocks),
-            block_size=self.block_size,
-        )
-        self.logical_token_blocks.append(block)
-
-    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
-        cursor = 0
-        while cursor < len(token_ids):
-            if not self.logical_token_blocks:
-                self._append_logical_block()
-
-            last_block = self.logical_token_blocks[-1]
-            if last_block.is_full():
-                self._append_logical_block()
-                last_block = self.logical_token_blocks[-1]
-
-            num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append_tokens(token_ids[cursor:cursor +
-                                               num_empty_slots])
-            cursor += num_empty_slots
-
     def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
     ) -> None:
         assert token_id in logprobs
-        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
         self.data.append_token_id(token_id, logprobs[token_id].logprob)

@@ -302,24 +350,22 @@ class Sequence:
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

-    def get_prompt_token_ids(self) -> List[int]:
+    def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.data.get_prompt_token_ids()

     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()

-    def get_output_token_ids(self) -> List[int]:
-        return self.data.output_token_ids
+    def get_output_token_ids(self) -> Tuple[int, ...]:
+        return self.data.get_output_token_ids()

     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

-    def get_beam_search_score(
-        self,
-        length_penalty: float = 1.0,
-        seq_len: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-    ) -> float:
+    def get_beam_search_score(self,
+                              length_penalty: float = 1.0,
+                              seq_len: Optional[int] = None,
+                              eos_token_id: Optional[int] = None) -> float:
         """Calculate the beam search score with length penalty.

         Adapted from
@@ -345,12 +391,10 @@ class Sequence:

     def get_num_new_tokens(self) -> int:
         """Get the number of new tokens to be computed.
-        Args:
-            remainig_token_budget: The remaining token budgets.
+
         Returns:
-            The new number of tokens to be computed. I.e., 1 for decode, prompt
-            size for prefill. If there's not enough remainig_token_budget, it
-            can return the chunked number of new tokens.
+            The new number of tokens to be computed. I.e., 1 for decode, or
+            the remaining prompt size for prefill.
         """
         if self.data.stage == SequenceStage.DECODE:
             return 1
@@ -362,34 +406,7 @@ class Sequence:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={len(self.logical_token_blocks)})")
-
-
-@dataclass
-class SequenceGroupState:
-    """Mutable state tied to a specific sequence group"""
-
-    # torch.Generator used in seeded sampling
-    generator: Optional = None
-
-
-class MultiModalData:
-    """Multi modal request.
-
-    Args:
-        type: The data type.
-        data: The actual data.
-            The required shape and semantic meaning of it depends on the vision
-            language config of the hosted model.
-            See `VisionLanguageConfig` in `config.py`.
-    """
-
-    class Type(enum.Enum):
-        IMAGE = enum.auto()
-
-    def __init__(self, type: Type, data: "torch.Tensor"):
-        self.type = type
-        self.data = data
+                f"num_blocks={self.n_blocks}, ")


 class SequenceGroup:
@@ -401,63 +418,101 @@ class SequenceGroup:
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
         lora_request: LoRA request.
-        multi_modal_requst: Multi modal data for the request.
+        embeddings: The embedding vectors of the prompt of the sequence group
+            for an embedding model.
+        pooling_params: The pooling parameters used to generate the pooling
+            for an embedding model.
+        encoder_seq: Optional, the single encoder sequence. Should be None
+                     unless you are working with an encoder/decoder model.
+        prompt_adapter_request: Prompt adapter request.
     """

     def __init__(
         self,
         request_id: str,
         seqs: List[Sequence],
-        sampling_params: SamplingParams,
         arrival_time: float,
+        sampling_params: Optional[SamplingParams] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        embeddings: Optional[List[float]] = None,
+        pooling_params: Optional[PoolingParams] = None,
+        encoder_seq: Optional[Sequence] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
+        self.seqs = seqs
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
-        self.metrics = RequestMetrics(
-            arrival_time=arrival_time,
-            last_token_time=arrival_time,
-            first_scheduled_time=None,
-            first_token_time=None,
-            time_in_queue=None,
-        )
+        self.metrics = RequestMetrics(arrival_time=arrival_time,
+                                      last_token_time=arrival_time,
+                                      first_scheduled_time=None,
+                                      first_token_time=None,
+                                      time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
-        self.state = SequenceGroupState()
-        self.multi_modal_data = multi_modal_data
+        self.embeddings = embeddings
+        self.pooling_params = pooling_params
+        self.prompt_adapter_request = prompt_adapter_request
+        self.encoder_seq = encoder_seq

     @property
-    def prompt(self) -> str:
+    def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return next(iter(self.seqs_dict.values())).prompt
+        return self.seqs[0].prompt

     @property
     def prompt_token_ids(self) -> List[int]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return next(iter(self.seqs_dict.values())).data.prompt_token_ids
+        return self.seqs[0].prompt_token_ids
+
+    @property
+    def multi_modal_data(self) -> "MultiModalDataDict":
+        # All sequences in the group should have the same multi-modal data.
+        # We use the multi-modal data of an arbitrary sequence.
+        return self.seqs[0].multi_modal_data

     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

-    def get_last_latency(self, now: float) -> float:
-        """Gets last token latency for Request level timings."""
+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+                        if self.prompt_adapter_request else 0
+
+    @property
+    def prompt_adapter_num_virtual_tokens(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+                         if self.prompt_adapter_request else 0
+
+    def get_last_latency(self, now: float) -> Optional[float]:
+        """Sets the last token time for Request level timings."""
+        # If still in prefill phase, raise Error.
+        if self.is_prefill():
+            raise ValueError(
+                "seq_group.get_last_latency() should not be called "
+                "if the seq_group is in prefill phase.")
+
+        # Otherwise return token latency.
         latency = now - self.metrics.last_token_time
         self.metrics.last_token_time = now
         return latency

     def maybe_set_first_token_time(self, time: float) -> None:
         """Sets the first token time for Request level timings."""
-        if self.metrics.first_token_time is None:
+        # NOTE: in a case where a sequence_group is swapped and
+        #   recomputed, the time between iterations is counted
+        #   in TPOT, rather than recalculating TTFT (since from the
+        #   POV of the user, there is simply a long generation delay).
+        if (self.metrics.first_token_time is None
+                and self.seqs[0].get_output_len() == 1):
             self.metrics.first_token_time = time

     def maybe_set_first_scheduled_time(self, time: float) -> None:
-        """Sets the first scheduled time and time in queue for Request level
-        timings."""
+        """Sets the first scheduled time and time in queue for Request
+        level timings."""
         if self.metrics.first_scheduled_time is None:
             self.metrics.first_scheduled_time = time
             self.metrics.time_in_queue = time - self.metrics.arrival_time
@@ -469,12 +524,13 @@ class SequenceGroup:
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
-        if self.sampling_params.use_beam_search:
+        if self.sampling_params and self.sampling_params.use_beam_search:
             # For beam search, maximally there will always be `best_of` beam
             # candidates running in the future.
             return self.sampling_params.best_of
         else:
-            if self.sampling_params.best_of > self.num_seqs():
+            if (self.sampling_params
+                    and self.sampling_params.best_of > self.num_seqs()):
                 # At prompt stage, the sequence group is not yet filled up
                 # and only have one sequence running. However, in the
                 # generation stage, we will have `best_of` sequences running.
@@ -487,27 +543,31 @@ class SequenceGroup:
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        return (list(self.seqs_dict.values()) if status is None else [
-            seq for seq in self.seqs_dict.values() if seq.status == status
-        ])
+        if status is None:
+            return self.seqs
+        return [seq for seq in self.seqs if seq.status == status]
+
+    def is_encoder_decoder(self) -> bool:
+        return self.encoder_seq is not None
+
+    def get_encoder_seq(self) -> Optional[Sequence]:
+        return self.encoder_seq

     def get_unfinished_seqs(self) -> List[Sequence]:
-        return [
-            seq for seq in self.seqs_dict.values() if not seq.is_finished()
-        ]
+        return [seq for seq in self.seqs if not seq.is_finished()]

     def get_finished_seqs(self) -> List[Sequence]:
-        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
+        return [seq for seq in self.seqs if seq.is_finished()]

     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        for seq in self.seqs_dict.values():
+        for seq in self.seqs:
             if not seq.is_finished():
                 seq.data.update_num_computed_tokens(num_new_computed_tokens)

     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        for seq in self.get_seqs():
+        for seq in self.seqs:
             if not seq.is_finished():
                 num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
@@ -516,7 +576,7 @@ class SequenceGroup:
         # Optimization. We don't need to call get_seqs if we don't need to
         # filter by states.
         if status is None:
-            return len(self.seqs_dict)
+            return len(self.seqs)

         return len(self.get_seqs(status))

@@ -535,23 +595,25 @@ class SequenceGroup:
         if seq.seq_id in self.seqs_dict:
             raise ValueError(f"Sequence {seq.seq_id} already exists.")
         self.seqs_dict[seq.seq_id] = seq
+        self.seqs.append(seq)

     def remove(self, seq_id: int) -> None:
-        if seq_id not in self.seqs_dict:
+        seq = self.seqs_dict.pop(seq_id, None)
+        if seq is None:
             raise ValueError(f"Sequence {seq_id} not found.")
-        del self.seqs_dict[seq_id]
+        self.seqs.remove(seq)

     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.get_seqs())
+        return all(seq.is_finished() for seq in self.seqs)

     def is_prefill(self) -> bool:
-        # Every sequences should be in the same stage.
-        return self.get_seqs()[0].is_prefill()
+        # Every sequence should be in the same stage.
+        return self.seqs[0].is_prefill()

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs_dict)})")
+                f"num_seqs={len(self.seqs)})")


 class SequenceGroupMetadata:
@@ -564,12 +626,25 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        do_sample: True if sampling is required. Sampling is not required when
+            e.g., prefill is chunked and the current iteration only computes
+            query tokens for prefill; in that case we don't need sampling.
         token_chunk_size: The number of tokens to be processed (per sequence).
             None if chunking is not required.
-        state: Internal state tied to this sequence group.
         lora_request: LoRA request.
-        multi_modal_data: Multi modal data for the request.
-        persistent_data: The persistent data of the sequence group.
+        computed_block_nums: The block numbers that are already computed,
+            used in prefix caching.
+        multi_modal_data: Multi modal data.
+        encoder_seq_data: Optional sequence data for encoder prompt
+                          (SequenceGroup.encoder_seq). Should be None
+                          unless you are working with an encoder/decoder
+                          model.
+        cross_block_table: Optional cross-attention block table associated
+                           with the encoder prompt
+                           (SequenceGroup.encoder_seq). Should be None
+                           unless you are working with an encoder/decoder
+                           model.
+        prompt_adapter_request: Prompt Adapter request.
     """

     def __init__(
@@ -579,24 +654,36 @@ class SequenceGroupMetadata:
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
-        persistent_data: Dict[int, dict],
+        do_sample: bool = True,
+        pooling_params: Optional[PoolingParams] = None,
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
-        state: Optional[SequenceGroupState] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        multi_modal_data: Optional["MultiModalDataDict"] = None,
+        encoder_seq_data: Optional[SequenceData] = None,
+        cross_block_table: Optional[List[int]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
         self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
-        self.persistent_data = persistent_data
+        self.pooling_params = pooling_params
         self.lora_request = lora_request
+        self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
-        self.state = SequenceGroupState() if state is None else state
         self.multi_modal_data = multi_modal_data
+        self.encoder_seq_data = encoder_seq_data
+        self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
+        self.do_sample = do_sample
+
+        # The number of speculative tokens adopted in this request.
+        # None means speculative decoding is not used.
+        # Zero means speculative decoding is disabled for some reason.
+        # TODO: We should maintain this state outside of the sequence group.
+        self.num_speculative_tokens = None

         if self._token_chunk_size is None:
             if is_prompt:
@@ -608,9 +695,20 @@ class SequenceGroupMetadata:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+                        if self.prompt_adapter_request else 0
+
+    @property
+    def prompt_adapter_num_virtual_tokens(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
+            if self.prompt_adapter_request else 0
+
     @property
     def token_chunk_size(self) -> int:
         """Return the number of tokens to be processed (chunk size)."""
+        assert self._token_chunk_size is not None
         return self._token_chunk_size


@@ -623,7 +721,6 @@ class SequenceOutput:
         output_token: The output token ID.
         logprobs: The logprobs of the output token.
             (Token id -> logP(x_i+1 | x_0, ..., x_i))
-        persistent_data: The persistent data of the sequence.
     """

     def __init__(
@@ -631,18 +728,15 @@ class SequenceOutput:
         parent_seq_id: int,
         output_token: int,
         logprobs: Dict[int, Logprob],
-        persistent_data: dict,
     ) -> None:
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs
-        self.persistent_data = persistent_data

     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"logprobs={self.logprobs}, "
-                f"persistent_data={self.persistent_data})")
+                f"logprobs={self.logprobs})")

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
@@ -653,8 +747,20 @@ class SequenceOutput:
         return equal and log_probs_equal


-class SequenceGroupOutput:
-    """The model output associated with a sequence group."""
+class SequenceGroupOutput(ABC):
+    """The base class for model outputs associated with a sequence group."""
+
+    @abstractmethod
+    def __repr__(self) -> str:
+        pass
+
+    @abstractmethod
+    def __eq__(self, other: object) -> bool:
+        pass
+
+
+class CompletionSequenceGroupOutput(SequenceGroupOutput):
+    """The model output associated with a completion sequence group."""

     def __init__(
         self,
@@ -662,39 +768,93 @@ class SequenceGroupOutput:
         prompt_logprobs: Optional[PromptLogprobs],
     ) -> None:
         self.samples = samples
+        # Prompt logprob for each prompt query token.
         self.prompt_logprobs = prompt_logprobs

     def __repr__(self) -> str:
-        return (f"SequenceGroupOutput(samples={self.samples}, "
+        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                 f"prompt_logprobs={self.prompt_logprobs})")

     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, SequenceGroupOutput):
+        if not isinstance(other, CompletionSequenceGroupOutput):
             raise NotImplementedError()
         return (self.samples == other.samples
                 and self.prompt_logprobs == other.prompt_logprobs)


+class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
+    """The model output associated with an embedding sequence group."""
+
+    def __init__(
+        self,
+        embeddings: List[float],
+    ) -> None:
+        self.embeddings = embeddings
+
+    def __repr__(self) -> str:
+        return (f"EmbeddingSequenceGroupOutput("
+                f"embeddings_shape={len(self.embeddings)})")
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, EmbeddingSequenceGroupOutput):
+            raise NotImplementedError()
+        return self.embeddings == other.embeddings
+
+
+@dataclass
+class IntermediateTensors:
+    """For all pipeline stages except the last, we need to return the hidden
+    states and residuals to be sent to the next stage. This data structure
+    contains the hidden states and residuals for a request.
+    """
+
+    tensors: Dict[str, torch.Tensor]
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return (isinstance(other, self.__class__)
+                and self.tensors == other.tensors)
+
+    def __repr__(self) -> str:
+        return f"IntermediateTensors(tensors={self.tensors})"
+
+
 @dataclass
 class SamplerOutput:
     """For each sequence group, we generate a list of SequenceOutput object,
     each of which contains one possible candidate for the next token.

-    This datastructure implements methods so it can be used like a list, but
+    This data structure implements methods, so it can be used like a list, but
     also has optional fields for device tensors.
     """

-    outputs: List[SequenceGroupOutput]
+    outputs: List[CompletionSequenceGroupOutput]

     # On-device tensor containing probabilities of each token.
-    sampled_token_probs: Optional["torch.Tensor"] = None
+    sampled_token_probs: Optional[torch.Tensor] = None
+
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None

     # On-device tensor containing the sampled token ids.
-    sampled_token_ids: Optional["torch.Tensor"] = None
+    sampled_token_ids: Optional[torch.Tensor] = None

     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None

+    # Optional last hidden states from the model.
+    hidden_states: Optional[torch.Tensor] = None
+
     def __getitem__(self, idx: int):
         return self.outputs[idx]
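
Illustrative aside (not part of the patch; dummy tensor shapes chosen for the example): IntermediateTensors can be indexed by name to get one tensor, or by a slice to get a new IntermediateTensors with every tensor sliced along the batch dimension.

    import torch
    from aphrodite.common.sequence import IntermediateTensors

    it = IntermediateTensors(tensors={
        "hidden_states": torch.zeros(8, 16),
        "residual": torch.zeros(8, 16),
    })

    assert it["hidden_states"].shape == (8, 16)  # lookup by name
    first_half = it[:4]                          # slice the batch dimension
    assert isinstance(first_half, IntermediateTensors)
    assert len(first_half) == 2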
@@ -705,8 +865,8 @@ class SamplerOutput:
         return len(self.outputs)

     def __eq__(self, other: object):
-        return (isinstance(other, self.__class__)
-                and self.outputs == other.outputs)
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs

     def __repr__(self) -> str:
         """Show the shape of a tensor instead of its values to reduce noise.
@@ -720,3 +880,122 @@ class SamplerOutput:
             f"sampled_token_probs={sampled_token_probs_repr}, "
             f"sampled_token_ids={sampled_token_ids_repr}, "
             f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
+
+@dataclass
+class PoolerOutput:
+    """The output from a pooling operation in the embedding model."""
+    outputs: List[EmbeddingSequenceGroupOutput]
+
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+
+def get_all_seq_ids(
+        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
+
+
+def get_all_seq_ids_and_request_ids(
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+) -> Tuple[List[int], Dict[str, Set[int]]]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    seq_ids: List[int] = []
+    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+    for sg in seq_group_metadata_list:
+        for seq_id in sg.seq_data:
+            seq_ids.append(seq_id)
+            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
+    return seq_ids, request_id_seq_ids_mapping
+
+
+class HiddenStates:
+    """Hidden states corresponding to in-progress sequences.
+    Used in speculative decoding to pass hidden states from
+    the target model to the proposer model in the subsequent step.
+
+    seq_ids are the sequence ids of each entry of the batch
+    dimension of the hidden_states tensor"""
+
+    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+                 hidden_states: torch.Tensor):
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
+        self.hidden_states: torch.Tensor = hidden_states
+
+    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+               hidden_states: torch.Tensor) -> None:
+        """Update hidden states from target model invocation."""
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        self.hidden_states = torch.cat([self.hidden_states, hidden_states])
+
+    def prune(self,
+              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+        """Prune to provided list of sequence ids."""
+        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        if seq_ids != self.seq_ids:
+            # Batch contents changed - prune removed sequences.
+            index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
+            self.hidden_states = self.hidden_states[index]
+            self.seq_ids = seq_ids
+
+
+@dataclass
+class ExecuteModelRequest:
+    """The model execution request, containing CPU metadata only. The LLM
+    engine should create an instance of this class for each request batch."""
+    # The sequence group metadata list.
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
+    # Blocks to copy. Source to dest block.
+    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
+    # Virtual engine ID for pipeline parallel.
+    virtual_engine: int = 0
+    # The number of slots for lookahead decoding.
+    num_lookahead_slots: int = 0
+    # The number of requests in the running queue.
+    running_queue_size: int = 0
+    # Optional hidden states from prior step.
+    previous_hidden_states: Optional[HiddenStates] = None
+    # The number of forward steps to run.
+    num_steps: int = 1
+    # Finished request ids since last step.
+    finished_requests_ids: List[str] = field(default_factory=list)
+
+    def clone(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> "ExecuteModelRequest":
+        """Clone the request with a new sequence group metadata list."""
+        return ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
+            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
+            blocks_to_copy=self.blocks_to_copy.copy(),
+            virtual_engine=self.virtual_engine,
+            num_lookahead_slots=self.num_lookahead_slots,
+            running_queue_size=self.running_queue_size,
+            previous_hidden_states=self.previous_hidden_states,
+            num_steps=self.num_steps,
+            finished_requests_ids=self.finished_requests_ids,
+        )
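
Illustrative aside (not part of the patch; an empty metadata list stands in for real scheduler output): ExecuteModelRequest.clone() reuses the supplied metadata list but copies the block lists, so the next step can mutate them without touching the original request.

    from aphrodite.common.sequence import ExecuteModelRequest

    req = ExecuteModelRequest(
        seq_group_metadata_list=[],
        blocks_to_swap_in=[(0, 3)],
        num_lookahead_slots=2,
    )

    next_req = req.clone(seq_group_metadata_list=[])
    assert next_req.blocks_to_swap_in == [(0, 3)]
    assert next_req.blocks_to_swap_in is not req.blocks_to_swap_in  # copied
    assert next_req.num_lookahead_slots == 2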