|
@@ -14,18 +14,17 @@ import torch
|
|
|
|
|
|
from aphrodite.common.pooling_params import PoolingParams
|
|
from aphrodite.common.pooling_params import PoolingParams
|
|
from aphrodite.common.sampling_params import SamplingParams
|
|
from aphrodite.common.sampling_params import SamplingParams
|
|
|
|
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
|
|
from aphrodite.inputs.parse import is_valid_encoder_decoder_llm_inputs
|
|
from aphrodite.inputs.parse import is_valid_encoder_decoder_llm_inputs
|
|
from aphrodite.lora.request import LoRARequest
|
|
from aphrodite.lora.request import LoRARequest
|
|
from aphrodite.prompt_adapter.request import PromptAdapterRequest
|
|
from aphrodite.prompt_adapter.request import PromptAdapterRequest
|
|
-from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
if TYPE_CHECKING:
|
|
from aphrodite.inputs import LLMInputs
|
|
from aphrodite.inputs import LLMInputs
|
|
from aphrodite.multimodal import MultiModalDataDict
|
|
from aphrodite.multimodal import MultiModalDataDict
|
|
|
|
+ from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
|
|
|
|
|
|
|
|
|
|
-APHRODITE_TOKEN_ID_ARRAY_TYPE = "l"
|
|
|
|
-
|
|
|
|
@dataclass
|
|
@dataclass
|
|
class Logprob:
|
|
class Logprob:
|
|
"""Infos for supporting OpenAI compatible logprobs and token ranks.
|
|
"""Infos for supporting OpenAI compatible logprobs and token ranks.
|
|
@@ -202,6 +201,7 @@ class SequenceData(msgspec.Struct,
|
|
compatible with torch.long (2 bytes vs 4 bytes).
|
|
compatible with torch.long (2 bytes vs 4 bytes).
|
|
Beware!
|
|
Beware!
|
|
"""
|
|
"""
|
|
|
|
+ assert isinstance(self._output_token_ids, array)
|
|
return self._output_token_ids
|
|
return self._output_token_ids
|
|
|
|
|
|
def append_token_id(self, token_id: int, logprob: float) -> None:
|
|
def append_token_id(self, token_id: int, logprob: float) -> None:
|
|
@@ -536,7 +536,6 @@ class Sequence:
|
|
f"num_blocks={self.n_blocks}, ")
|
|
f"num_blocks={self.n_blocks}, ")
|
|
|
|
|
|
|
|
|
|
-@dataclass
|
|
|
|
class SequenceGroupState(
|
|
class SequenceGroupState(
|
|
msgspec.Struct, omit_defaults=True):
|
|
msgspec.Struct, omit_defaults=True):
|
|
"""Mutable state tied to a specific sequence group"""
|
|
"""Mutable state tied to a specific sequence group"""
|
|
@@ -939,7 +938,6 @@ class SequenceGroupMetadata(
|
|
self.token_chunk_size = next(iter(
|
|
self.token_chunk_size = next(iter(
|
|
self.seq_data.values())).get_len()
|
|
self.seq_data.values())).get_len()
|
|
else:
|
|
else:
|
|
- self._token_chunk_size = 1
|
|
|
|
self.token_chunk_size = 1
|
|
self.token_chunk_size = 1
|
|
|
|
|
|
|
|
|
|
@@ -1022,6 +1020,7 @@ class CompletionSequenceGroupOutput(
|
|
omit_defaults=True,
|
|
omit_defaults=True,
|
|
array_like=True):
|
|
array_like=True):
|
|
"""The model output associated with a completion sequence group."""
|
|
"""The model output associated with a completion sequence group."""
|
|
|
|
+ __metaclass__ = SequenceGroupOutput
|
|
|
|
|
|
samples: List[SequenceOutput]
|
|
samples: List[SequenceOutput]
|
|
prompt_logprobs: Optional[PromptLogprobs]
|
|
prompt_logprobs: Optional[PromptLogprobs]
|
|
@@ -1056,7 +1055,6 @@ class EmbeddingSequenceGroupOutput(
|
|
return self.embeddings == other.embeddings
|
|
return self.embeddings == other.embeddings
|
|
|
|
|
|
|
|
|
|
-@dataclass
|
|
|
|
class IntermediateTensors(
|
|
class IntermediateTensors(
|
|
msgspec.Struct,
|
|
msgspec.Struct,
|
|
omit_defaults=True,
|
|
omit_defaults=True,
|
|
@@ -1087,7 +1085,6 @@ class IntermediateTensors(
|
|
return f"IntermediateTensors(tensors={self.tensors})"
|
|
return f"IntermediateTensors(tensors={self.tensors})"
|
|
|
|
|
|
|
|
|
|
-@dataclass
|
|
|
|
class SamplerOutput(
|
|
class SamplerOutput(
|
|
msgspec.Struct,
|
|
msgspec.Struct,
|
|
omit_defaults=True,
|
|
omit_defaults=True,
|
|
@@ -1112,7 +1109,7 @@ class SamplerOutput(
|
|
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
|
|
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
|
|
|
|
|
|
# Spec decode metrics populated by workers.
|
|
# Spec decode metrics populated by workers.
|
|
- spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
|
|
|
|
|
|
+ spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
|
|
|
|
|
# Optional last hidden states from the model.
|
|
# Optional last hidden states from the model.
|
|
hidden_states: Optional[torch.Tensor] = None
|
|
hidden_states: Optional[torch.Tensor] = None
|
|
@@ -1144,7 +1141,6 @@ class SamplerOutput(
|
|
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|
|
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|
|
|
|
|
|
|
|
|
|
-@dataclass
|
|
|
|
class PoolerOutput(
|
|
class PoolerOutput(
|
|
msgspec.Struct,
|
|
msgspec.Struct,
|
|
omit_defaults=True,
|
|
omit_defaults=True,
|
|
@@ -1152,7 +1148,7 @@ class PoolerOutput(
|
|
"""The output from a pooling operation in the embedding model."""
|
|
"""The output from a pooling operation in the embedding model."""
|
|
outputs: List[EmbeddingSequenceGroupOutput]
|
|
outputs: List[EmbeddingSequenceGroupOutput]
|
|
|
|
|
|
- spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
|
|
|
|
|
|
+ spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
|
|
|
|
|
def __getitem__(self, idx: int):
|
|
def __getitem__(self, idx: int):
|
|
return self.outputs[idx]
|
|
return self.outputs[idx]
|
|
@@ -1233,7 +1229,6 @@ class HiddenStates(
|
|
self._seq_ids = seq_ids
|
|
self._seq_ids = seq_ids
|
|
|
|
|
|
|
|
|
|
-@dataclass
|
|
|
|
class ExecuteModelRequest(
|
|
class ExecuteModelRequest(
|
|
msgspec.Struct,
|
|
msgspec.Struct,
|
|
omit_defaults=True,
|
|
omit_defaults=True,
|