Browse Source

fix: circular import

AlpinDale 4 months ago
parent
commit
2242dcbcf2
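
Note: the cycle here was aphrodite.common.sequence importing
aphrodite.spec_decode.metrics, which transitively imports sequence again.
The commit breaks it in two standard ways, sketched below using the names
from this diff (the module bodies are illustrative, not the full files):

    # aphrodite/constants.py -- a leaf module with no project imports,
    # so sequence.py and inputs/registry.py can both import it safely.
    APHRODITE_TOKEN_ID_ARRAY_TYPE = "l"

    # aphrodite/common/sequence.py
    from typing import TYPE_CHECKING, Optional

    from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE

    if TYPE_CHECKING:
        # Evaluated only by type checkers, never at runtime, so the
        # sequence -> spec_decode -> sequence import cycle cannot form.
        from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics

    class SamplerOutput:  # a msgspec.Struct in the real file
        # A string annotation keeps the type usable without the import.
        spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None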

+ 6 - 11
aphrodite/common/sequence.py

@@ -14,18 +14,17 @@ import torch
 
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs.parse import is_valid_encoder_decoder_llm_inputs
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 if TYPE_CHECKING:
     from aphrodite.inputs import LLMInputs
     from aphrodite.multimodal import MultiModalDataDict
+    from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 
-APHRODITE_TOKEN_ID_ARRAY_TYPE = "l"
-
 @dataclass
 class Logprob:
     """Infos for supporting OpenAI compatible logprobs and token ranks.
@@ -202,6 +201,7 @@ class SequenceData(msgspec.Struct,
         compatible with torch.long (2 bytes vs 4 bytes).
         Beware!
         """
+        assert isinstance(self._output_token_ids, array)
         return self._output_token_ids
 
     def append_token_id(self, token_id: int, logprob: float) -> None:
@@ -536,7 +536,6 @@ class Sequence:
                 f"num_blocks={self.n_blocks}, ")
 
 
-@dataclass
 class SequenceGroupState(
     msgspec.Struct, omit_defaults=True):
     """Mutable state tied to a specific sequence group"""
@@ -939,7 +938,6 @@ class SequenceGroupMetadata(
                 self.token_chunk_size = next(iter(
                     self.seq_data.values())).get_len()
             else:
-                self._token_chunk_size = 1
                 self.token_chunk_size = 1
 
 
@@ -1022,6 +1020,7 @@ class CompletionSequenceGroupOutput(
     omit_defaults=True,
     array_like=True):
     """The model output associated with a completion sequence group."""
+    __metaclass__ = SequenceGroupOutput
 
     samples: List[SequenceOutput]
     prompt_logprobs: Optional[PromptLogprobs]
@@ -1056,7 +1055,6 @@ class EmbeddingSequenceGroupOutput(
         return self.embeddings == other.embeddings
 
 
-@dataclass
 class IntermediateTensors(
     msgspec.Struct,
     omit_defaults=True,
@@ -1087,7 +1085,6 @@ class IntermediateTensors(
         return f"IntermediateTensors(tensors={self.tensors})"
 
 
-@dataclass
 class SamplerOutput(
     msgspec.Struct,
     omit_defaults=True,
@@ -1112,7 +1109,7 @@ class SamplerOutput(
     sampled_token_ids_numpy: Optional[numpy.ndarray] = None
 
     # Spec decode metrics populated by workers.
-    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
     # Optional last hidden states from the model.
     hidden_states: Optional[torch.Tensor] = None
@@ -1144,7 +1141,6 @@ class SamplerOutput(
             f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
 
 
-@dataclass
 class PoolerOutput(
     msgspec.Struct,
     omit_defaults=True,
@@ -1152,7 +1148,7 @@ class PoolerOutput(
     """The output from a pooling operation in the embedding model."""
     outputs: List[EmbeddingSequenceGroupOutput]
 
-    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
     def __getitem__(self, idx: int):
         return self.outputs[idx]
@@ -1233,7 +1229,6 @@ class HiddenStates(
             self._seq_ids = seq_ids
 
 
-@dataclass
 class ExecuteModelRequest(
     msgspec.Struct,
     omit_defaults=True,

+ 1 - 0
aphrodite/constants.py

@@ -0,0 +1 @@
+APHRODITE_TOKEN_ID_ARRAY_TYPE = "l"
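
"l" is the array.array typecode for a signed long; SequenceData keeps token
ids in compact typed arrays rather than Python lists. Minimal usage sketch:

    from array import array

    from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE

    # token ids packed as machine longs instead of boxed Python ints
    prompt_token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])
    prompt_token_ids.append(4)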

+ 4 - 1
aphrodite/engine/args_tools.py

@@ -1,6 +1,7 @@
 import argparse
 import dataclasses
 import json
+import os
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
                     Union)
@@ -15,7 +16,6 @@ from aphrodite.common.config import (CacheConfig, ConfigFormat, DecodingConfig,
                                      TokenizerPoolConfig)
 from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
 from aphrodite.executor.executor_base import ExecutorBase
-from aphrodite.executor.ray_gpu_executor import APHRODITE_USE_RAY_SPMD_WORKER
 from aphrodite.quantization import QUANTIZATION_METHODS
 from aphrodite.transformers_utils.utils import check_gguf_file
 from aphrodite.triton_utils import HAS_TRITON
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
     from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
 
+APHRODITE_USE_RAY_SPMD_WORKER = bool(
+    os.getenv("APHRODITE_USE_RAY_SPMD_WORKER", 0))
+
 def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
     if len(val) == 0:
         return None
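
The inlined flag is read from the environment at import time. One caveat with
bool(os.getenv(...)): any non-empty string is truthy, so even setting
APHRODITE_USE_RAY_SPMD_WORKER=0 enables the flag. A stricter parse (an
illustrative alternative, not what this commit does) would be:

    import os

    # Only an explicit "1"/"true"/"yes" enables the flag; "0" disables it.
    APHRODITE_USE_RAY_SPMD_WORKER = os.getenv(
        "APHRODITE_USE_RAY_SPMD_WORKER", "0").lower() in ("1", "true", "yes")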

+ 1 - 1
aphrodite/executor/msgspec_utils.py

@@ -1,7 +1,7 @@
 from array import array
 from typing import Any, Type
 
-from aphrodite.common.sequence import APHRODITE_TOKEN_ID_ARRAY_TYPE
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 
 
 def encode_hook(obj: Any) -> Any:
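
For context, encode_hook is the msgspec extension point that serializes the
token-id arrays; only its import changes here. A minimal sketch of such a
hook pair, under the assumption that arrays travel as raw bytes (the real
file may differ in details):

    from array import array
    from typing import Any, Type

    from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE

    def encode_hook(obj: Any) -> Any:
        # msgspec cannot encode array.array natively; ship the raw bytes.
        if isinstance(obj, array):
            assert obj.typecode == APHRODITE_TOKEN_ID_ARRAY_TYPE
            return obj.tobytes()
        raise NotImplementedError(f"Cannot encode {type(obj)}")

    def decode_hook(type: Type, obj: Any) -> Any:
        # Rebuild the typed array from the bytes produced above.
        if type is array:
            deserialized = array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
            deserialized.frombytes(obj)
            return deserialized
        return obj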

+ 2 - 4
aphrodite/inputs/registry.py

@@ -10,6 +10,8 @@ from torch import nn
 from transformers import PretrainedConfig
 from typing_extensions import TypeVar
 
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
+
 from .data import LLMInputs
 
 if TYPE_CHECKING:
@@ -19,10 +21,6 @@ if TYPE_CHECKING:
 
 C = TypeVar("C", bound=PretrainedConfig)
 
-# NOTE: This has to match with sequence.py's `APHRODITE_TOKEN_ID_ARRAY_TYPE`.
-# We cannot import it here because of circular dependencies.
-APHRODITE_TOKEN_ID_ARRAY_TYPE = "l"
-
 
 @dataclass(frozen=True)
 class InputContext:

+ 2 - 2
aphrodite/modeling/models/blip.py

@@ -10,8 +10,8 @@ from transformers import Blip2VisionConfig, BlipVisionConfig
 from transformers.models.blip.modeling_blip import BlipAttention
 
 from aphrodite.common.config import ModelConfig
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SequenceData)
+from aphrodite.common.sequence import SequenceData
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,

+ 2 - 2
aphrodite/modeling/models/blip2.py

@@ -9,9 +9,9 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor

+ 2 - 2
aphrodite/modeling/models/chameleon.py

@@ -11,10 +11,10 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
 from aphrodite.common.utils import print_warning_once
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import SiluAndMul

+ 2 - 2
aphrodite/modeling/models/clip.py

@@ -10,8 +10,8 @@ from transformers import CLIPVisionConfig
 from transformers.models.clip.modeling_clip import CLIPAttention
 
 from aphrodite.common.config import ModelConfig
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SequenceData)
+from aphrodite.common.sequence import SequenceData
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,

+ 2 - 2
aphrodite/modeling/models/fuyu.py

@@ -27,9 +27,9 @@ from transformers import FuyuConfig, FuyuImageProcessor
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ColumnParallelLinear
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/minicpmv.py

@@ -39,9 +39,9 @@ from transformers.configuration_utils import PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ReplicatedLinear
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor

+ 2 - 2
aphrodite/modeling/models/siglip.py

@@ -14,8 +14,8 @@ from transformers.models.siglip.modeling_siglip import SiglipAttention
 from xformers.ops import memory_efficient_attention
 
 from aphrodite.common.config import ModelConfig
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SequenceData)
+from aphrodite.common.sequence import SequenceData
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.inputs import LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn

+ 2 - 2
aphrodite/modeling/sampling_metadata.py

@@ -6,11 +6,11 @@ from typing import Dict, List, Optional, Tuple
 import torch
 
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SequenceData, SequenceGroupMetadata)
+from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
 from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                     is_pin_memory_available,
                                     make_tensor_with_pad, maybe_expand_dim)
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.triton_utils.sample import get_num_triton_sampler_splits
 
 _SAMPLING_EPS = 1e-5

+ 0 - 2
aphrodite/spec_decode/draft_model_runner.py

@@ -66,7 +66,6 @@ class TP1DraftModelRunner(ModelRunner):
         is_driver_worker: bool = False,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         return_hidden_states: bool = False,
-        **kwargs,
     ):
         if return_hidden_states:
             raise ValueError(
@@ -85,7 +84,6 @@ class TP1DraftModelRunner(ModelRunner):
             is_driver_worker=is_driver_worker,
             prompt_adapter_config=prompt_adapter_config,
             return_hidden_states=return_hidden_states,
-            **kwargs,  # needed for uneven TP
         )
 
         self.flashinfer_decode_workspace_buffer = None

+ 1 - 0
aphrodite/spec_decode/medusa_worker.py

@@ -50,6 +50,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
         Returns the list of sampler output, one per layer, along with indicator
         of whether torch tensor in sampler output need to be transposed in
         latter sampler_output_to_torch logic.
+
         For medusa worker, this indicator shall be False.
         """
         self._raise_if_unsupported(execute_model_req)

+ 5 - 3
aphrodite/spec_decode/metrics.py

@@ -10,9 +10,9 @@ from aphrodite.modeling.layers.spec_decode_base_sampler import (
 
 
 class SpecDecodeWorkerMetrics(
-    msgspec.Struct,
-    omit_defaults=True,
-    array_like=True):
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
     """Dataclass holding metrics emitted from the spec decode worker.
     """
 
@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
     def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
         """Copy rejection/typical-acceptance sampling metrics 
         (number of accepted tokens, etc) to CPU asynchronously.
+
         Returns a CUDA event recording when the copy is complete.
         """
         assert self._copy_stream is not None
@@ -130,6 +131,7 @@ class AsyncMetricsCollector:
 
         aggregate_metrics_ready = torch.cuda.Event()
         aggregate_metrics_ready.record(self._copy_stream)
+
         return aggregate_metrics_ready
 
     def _collect_rejsample_metrics(
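
The reindented signature follows msgspec's pattern of configuring Struct
subclasses through class keywords; the "# type: ignore[call-arg]" comments
silence type checkers that do not understand them. A self-contained example
of the same pattern (hypothetical struct, not from this repo):

    import msgspec

    class Point(msgspec.Struct, omit_defaults=True, array_like=True):
        x: int
        y: int = 0

    # array_like encodes fields positionally; omit_defaults drops trailing
    # defaults, so Point(1) serializes as [1], not {"x": 1, "y": 0}.
    print(msgspec.json.encode(Point(1)))  # b'[1]'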

+ 2 - 0
aphrodite/spec_decode/mlp_speculator_worker.py

@@ -11,6 +11,7 @@ from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 
 class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
     """Worker for MLPSpeculator models.
+
     Not currently compatible with LoRA or chunked prefill.
     """
 
@@ -27,6 +28,7 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         Returns the list of sampler output, one per layer, along with indicator
         of whether torch tensor in sampler output need to be transposed in
         latter sampler_output_to_torch logic.
+
         For mlp spec worker, this indicator shall be True.
         """
         self._raise_if_unsupported(execute_model_req)

+ 6 - 7
aphrodite/spec_decode/ngram_worker.py

@@ -12,7 +12,7 @@ from aphrodite.spec_decode.top1_proposer import Top1Proposer
 class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.
 
-    Current NGramWorker only implement prompt lookup decoding,
+    Current NGramWorker only implements prompt lookup decoding,
     and in future we may also do RAG type drafter and other scenarios
     which don't rely on LLM model to give proposals.
     """
@@ -36,9 +36,9 @@ class NGramWorker(NonLLMProposerWorkerBase):
         self.device = torch.device(f"cuda:{self.local_rank}")
         self.load_model = lambda *args, **kwargs: None
 
-        # Current only support Top1Proposer
+        # Current NGramWorker only supports Top1Proposer
         self._proposer = Top1Proposer(
-            weakref.proxy(self),
+            weakref.proxy(self),  # type: ignore[arg-type]
             device=self.device,
             vocab_size=self.vocab_size,
         )
@@ -50,7 +50,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
         # Unused parameter. NGramWorker does not use the KV Cache and
         # therefore does not need this parameter.
         seq_ids_with_bonus_token_in_last_step: Set[int],
-    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
+    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
         """NGram match algo to pick proposal candidate. Returns the list of
         sampler output, one per SequenceGroupMetadata.
 
@@ -60,8 +60,8 @@ class NGramWorker(NonLLMProposerWorkerBase):
         self._raise_if_unsupported(execute_model_req)
 
         has_spec_out = False
-        token_id_list = []
-        token_prob_list = []
+        token_id_list: List[Optional[torch.Tensor]] = []
+        token_prob_list: List[Optional[torch.Tensor]] = []
         for idx, seq_group_metadata in enumerate(
                 execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
@@ -142,7 +142,6 @@ class NGramWorker(NonLLMProposerWorkerBase):
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """
-
         return self._proposer.get_spec_proposals(
             execute_model_req, seq_ids_with_bonus_token_in_last_step)
 

+ 0 - 19
aphrodite/spec_decode/proposer_worker_base.py

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple
 
 from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
-from aphrodite.lora.request import LoRARequest
 from aphrodite.spec_decode.interfaces import SpeculativeProposer
 from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
 
@@ -33,15 +32,6 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
         """Implementation optional"""
         pass
 
-    def add_lora(self, lora_request: LoRARequest) -> bool:
-        raise ValueError(f"{type(self)} does not support LoRA")
-
-    def remove_lora(self, lora_id: int) -> bool:
-        raise ValueError(f"{type(self)} does not support LoRA")
-
-    def list_loras(self) -> Set[int]:
-        raise ValueError(f"{type(self)} does not support LoRA")
-
 
 class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
     """Proposer worker which does not use a model with kvcache"""
@@ -63,12 +53,3 @@ class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
 
     def get_cache_block_size_bytes(self) -> int:
         return 0
-
-    def add_lora(self, lora_request: LoRARequest) -> bool:
-        raise ValueError(f"{type(self)} does not support LoRA")
-
-    def remove_lora(self, lora_id: int) -> bool:
-        raise ValueError(f"{type(self)} does not support LoRA")
-
-    def list_loras(self) -> Set[int]:
-        raise ValueError(f"{type(self)} does not support LoRA")
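
The deleted add_lora/remove_lora/list_loras overrides were redundant:
ProposerWorkerBase already inherits from LoraNotSupportedWorkerBase, whose
methods raise the same error. Illustrative shape of the hierarchy (bodies
abbreviated, not the real file):

    from typing import Set

    class LoraNotSupportedWorkerBase:
        def add_lora(self, lora_request) -> bool:
            raise ValueError(f"{type(self)} does not support LoRA")

        def remove_lora(self, lora_id: int) -> bool:
            raise ValueError(f"{type(self)} does not support LoRA")

        def list_loras(self) -> Set[int]:
            raise ValueError(f"{type(self)} does not support LoRA")

    class ProposerWorkerBase(LoraNotSupportedWorkerBase):
        pass  # the raising implementations are inherited, so redefining
              # them here and in NonLLMProposerWorkerBase was dead code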

+ 2 - 0
aphrodite/spec_decode/smaller_tp_proposer_worker.py

@@ -16,6 +16,7 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
     """Class which allows a speculative draft model to run with smaller tensor
     parallel degree than target model.
     This reduces the communication overhead of small draft models.
+
     To implement this feature, this class differs behavior based on is_dummy
     flag, where dummy means worker that does not participate draft generation.
     Participating workers use a smaller tp group by patching Aphrodite's tensor
@@ -38,6 +39,7 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
 
     def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
         """Create a SmallerTpProposerWorker.
+
         Args:
             worker (MultiStepWorker): an actual worker wrapped with this class
             draft_ranks (List[int]): if this value is given, only the GPU ranks

+ 2 - 2
aphrodite/spec_decode/spec_decode_worker.py

@@ -173,8 +173,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker,
             scorer_worker,
             disable_logprobs=disable_logprobs,
-            disable_by_batch_size=disable_by_batch_size,
             disable_log_stats=disable_log_stats,
+            disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step)
 
@@ -497,7 +497,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         for both speculation cases (num_lookahead_slots>0) and non-speculation
         cases (e.g. prefill).
 
-        Returns True iff there are remaining sequences to process.
+        Returns True if there are remaining sequences to process.
         """
         assert self.rank != self._driver_rank
 

+ 2 - 3
aphrodite/spec_decode/top1_proposer.py

@@ -108,8 +108,7 @@ class Top1Proposer(SpeculativeProposer):
             proposal_token_ids=proposal_tokens,
             proposal_probs=proposal_probs,
             proposal_lens=proposal_lens,
-            no_proposals=maybe_sampler_output is None,
-        )
+            no_proposals=maybe_sampler_output is None)
 
         return proposals
 
@@ -139,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
 
             # Currently only proposal lens of 0 or the global batch proposal len
             # are supported.
-            # If max_proposal_len is defined, then we shall no exccess this
+            # If max_proposal_len is defined, then we shall not exceed this
             # quota for nonzero_proposal
             new_k = 0
             if (self.max_proposal_len is None

+ 2 - 2
aphrodite/spec_decode/util.py

@@ -19,10 +19,10 @@ def get_all_num_logprobs(
     sequence.
     """
 
-    all_num_logprobs = []
+    all_num_logprobs: List[int] = []
     for seq_group_metadata in seq_group_metadata_list:
         num_logprobs = seq_group_metadata.sampling_params.logprobs
-        if seq_group_metadata.sampling_params.logprobs is None:
+        if num_logprobs is None:
             num_logprobs = 0
         all_num_logprobs.append(num_logprobs)
 

+ 2 - 2
tests/samplers/test_sampler.py

@@ -8,10 +8,10 @@ import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
 
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SamplingParams, SequenceData,
+from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import Counter, is_pin_memory_available
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_random_seed

+ 2 - 2
tests/spec_decode/utils.py

@@ -8,12 +8,12 @@ from unittest.mock import MagicMock
 import torch
 
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       CompletionSequenceGroupOutput, Logprob,
+from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        SamplerOutput, SequenceData,
                                        SequenceGroupMetadata, SequenceOutput)
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port)
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.modeling.utils import set_random_seed
 from aphrodite.task_handler.cache_engine import CacheEngine

+ 2 - 2
tests/worker/test_encoder_decoder_model_runner.py

@@ -4,10 +4,10 @@ from typing import List
 import pytest
 import torch
 
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SamplingParams, SequenceData,
+from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import is_cpu
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.task_handler.enc_dec_model_runner import (
     EncoderDecoderModelRunner)

+ 2 - 2
tests/worker/test_model_runner.py

@@ -4,10 +4,10 @@ from typing import List
 import pytest
 import torch
 
-from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SamplingParams, SequenceData,
+from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import get_open_port
+from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.distributed.parallel_state import (
     ensure_model_parallel_initialized, init_distributed_environment)
 from aphrodite.engine.args_tools import EngineArgs