
core: support logprobs with multi-step scheduling (#963)

* deferred sampler results

* fix imports and implement within multistep worker

* update tests

* fix test

* fix sequence test

* fix unrelated gguf ruff issue
AlpinDale 2 months ago
parent
commit
0dfa6b60ec
100 changed files with 637 additions and 320 deletions
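In user-facing terms, this commit lets logprobs and prompt logprobs be requested while multi-step scheduling is active. A minimal usage sketch, assuming the standard Aphrodite offline API (LLM, SamplingParams) and the num_scheduler_steps engine argument; the model name and values are illustrative, not part of this commit:

# Hedged sketch: sample and prompt logprobs together with multi-step
# scheduling. Argument names follow the Aphrodite/vLLM-style API of this
# period; check them against your installed version.
from aphrodite import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8)

params = SamplingParams(
    max_tokens=16,
    logprobs=5,          # top-5 logprobs for each sampled token
    prompt_logprobs=1,   # per-token prompt logprobs (token 0 has none)
)

out = llm.generate(["The capital of France is"], params)[0]
print(out.prompt_logprobs)          # [None, {...}, {...}, ...]
print(out.outputs[0].logprobs[0])   # logprobs for the first generated token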
  1. 0 66
      aphrodite/common/sequence.py
  2. 4 3
      aphrodite/engine/aphrodite_engine.py
  3. 2 1
      aphrodite/engine/async_aphrodite.py
  4. 11 3
      aphrodite/engine/output_processor/multi_step.py
  5. 44 17
      aphrodite/engine/output_processor/single_step.py
  6. 2 2
      aphrodite/engine/output_processor/util.py
  7. 1 1
      aphrodite/engine/protocol.py
  8. 2 1
      aphrodite/executor/cpu_executor.py
  9. 2 1
      aphrodite/executor/distributed_gpu_executor.py
  10. 2 1
      aphrodite/executor/executor_base.py
  11. 2 2
      aphrodite/executor/gpu_executor.py
  12. 2 1
      aphrodite/executor/multiproc_gpu_executor.py
  13. 2 1
      aphrodite/executor/neuron_executor.py
  14. 2 1
      aphrodite/executor/openvino_executor.py
  15. 2 1
      aphrodite/executor/ray_gpu_executor.py
  16. 2 1
      aphrodite/executor/ray_tpu_executor.py
  17. 2 1
      aphrodite/executor/tpu_executor.py
  18. 2 2
      aphrodite/executor/xpu_executor.py
  19. 247 44
      aphrodite/modeling/layers/sampler.py
  20. 1 2
      aphrodite/modeling/model_loader/neuron.py
  21. 1 2
      aphrodite/modeling/model_loader/openvino.py
  22. 2 2
      aphrodite/modeling/models/arctic.py
  23. 2 2
      aphrodite/modeling/models/baichuan.py
  24. 2 2
      aphrodite/modeling/models/bart.py
  25. 2 3
      aphrodite/modeling/models/blip2.py
  26. 2 2
      aphrodite/modeling/models/bloom.py
  27. 2 3
      aphrodite/modeling/models/chameleon.py
  28. 2 2
      aphrodite/modeling/models/chatglm.py
  29. 2 2
      aphrodite/modeling/models/commandr.py
  30. 2 2
      aphrodite/modeling/models/dbrx.py
  31. 2 2
      aphrodite/modeling/models/deepseek.py
  32. 2 2
      aphrodite/modeling/models/deepseek_v2.py
  33. 2 1
      aphrodite/modeling/models/eagle.py
  34. 2 2
      aphrodite/modeling/models/exaone.py
  35. 2 2
      aphrodite/modeling/models/falcon.py
  36. 2 2
      aphrodite/modeling/models/fuyu.py
  37. 2 2
      aphrodite/modeling/models/gemma.py
  38. 2 2
      aphrodite/modeling/models/gemma2.py
  39. 2 2
      aphrodite/modeling/models/gpt2.py
  40. 2 2
      aphrodite/modeling/models/gpt_bigcode.py
  41. 2 2
      aphrodite/modeling/models/gpt_j.py
  42. 2 2
      aphrodite/modeling/models/gpt_neox.py
  43. 2 2
      aphrodite/modeling/models/internlm2.py
  44. 2 1
      aphrodite/modeling/models/internvl.py
  45. 2 2
      aphrodite/modeling/models/jais.py
  46. 2 2
      aphrodite/modeling/models/jamba.py
  47. 2 2
      aphrodite/modeling/models/llama.py
  48. 2 1
      aphrodite/modeling/models/llava.py
  49. 2 1
      aphrodite/modeling/models/llava_next.py
  50. 2 2
      aphrodite/modeling/models/mamba.py
  51. 1 1
      aphrodite/modeling/models/medusa.py
  52. 2 2
      aphrodite/modeling/models/minicpm.py
  53. 2 3
      aphrodite/modeling/models/minicpmv.py
  54. 2 2
      aphrodite/modeling/models/mixtral.py
  55. 2 2
      aphrodite/modeling/models/mixtral_quant.py
  56. 1 2
      aphrodite/modeling/models/mlp_speculator.py
  57. 2 2
      aphrodite/modeling/models/mpt.py
  58. 2 2
      aphrodite/modeling/models/nemotron.py
  59. 2 2
      aphrodite/modeling/models/olmo.py
  60. 2 2
      aphrodite/modeling/models/olmoe.py
  61. 2 2
      aphrodite/modeling/models/opt.py
  62. 2 2
      aphrodite/modeling/models/orion.py
  63. 2 2
      aphrodite/modeling/models/paligemma.py
  64. 2 2
      aphrodite/modeling/models/persimmon.py
  65. 2 2
      aphrodite/modeling/models/phi.py
  66. 2 2
      aphrodite/modeling/models/phi3_small.py
  67. 2 2
      aphrodite/modeling/models/phi3v.py
  68. 2 2
      aphrodite/modeling/models/qwen.py
  69. 2 2
      aphrodite/modeling/models/qwen2.py
  70. 2 2
      aphrodite/modeling/models/qwen2_moe.py
  71. 2 2
      aphrodite/modeling/models/solar.py
  72. 2 2
      aphrodite/modeling/models/stablelm.py
  73. 2 2
      aphrodite/modeling/models/starcoder2.py
  74. 2 1
      aphrodite/modeling/models/ultravox.py
  75. 2 2
      aphrodite/modeling/models/xverse.py
  76. 2 1
      aphrodite/quantization/gguf.py
  77. 3 3
      aphrodite/spec_decode/batch_expansion.py
  78. 2 2
      aphrodite/spec_decode/draft_model_runner.py
  79. 2 1
      aphrodite/spec_decode/medusa_worker.py
  80. 2 1
      aphrodite/spec_decode/mlp_speculator_worker.py
  81. 2 2
      aphrodite/spec_decode/multi_step_worker.py
  82. 2 1
      aphrodite/spec_decode/ngram_worker.py
  83. 2 1
      aphrodite/spec_decode/proposer_worker_base.py
  84. 2 1
      aphrodite/spec_decode/smaller_tp_proposer_worker.py
  85. 2 2
      aphrodite/spec_decode/spec_decode_worker.py
  86. 2 1
      aphrodite/spec_decode/top1_proposer.py
  87. 2 2
      aphrodite/spec_decode/util.py
  88. 2 1
      aphrodite/task_handler/cpu_model_runner.py
  89. 2 1
      aphrodite/task_handler/enc_dec_model_runner.py
  90. 2 1
      aphrodite/task_handler/model_runner.py
  91. 2 1
      aphrodite/task_handler/model_runner_base.py
  92. 147 29
      aphrodite/task_handler/multi_step_model_runner.py
  93. 2 1
      aphrodite/task_handler/multi_step_worker.py
  94. 2 1
      aphrodite/task_handler/neuron_model_runner.py
  95. 2 1
      aphrodite/task_handler/openvino_model_runner.py
  96. 2 1
      aphrodite/task_handler/openvino_worker.py
  97. 2 2
      aphrodite/task_handler/tpu_model_runner.py
  98. 2 1
      aphrodite/task_handler/worker.py
  99. 2 2
      aphrodite/task_handler/worker_base.py
  100. 2 1
      aphrodite/task_handler/xpu_model_runner.py

+ 0 - 66
aphrodite/common/sequence.py

@@ -1046,72 +1046,6 @@ class IntermediateTensors(
         return f"IntermediateTensors(tensors={self.tensors})"


-class SamplerOutput(
-        msgspec.Struct,
-        omit_defaults=True,  # type: ignore[call-arg]
-        array_like=True):  # type: ignore[call-arg]
-    """For each sequence group, we generate a list of SequenceOutput object,
-    each of which contains one possible candidate for the next token.
-
-    This data structure implements methods, so it can be used like a list, but
-    also has optional fields for device tensors.
-    """
-
-    outputs: List[CompletionSequenceGroupOutput]
-
-    # On-device tensor containing probabilities of each token.
-    sampled_token_probs: Optional[torch.Tensor] = None
-
-    # On-device tensor containing the logprobs of each token.
-    logprobs: Optional["torch.Tensor"] = None
-
-    # On-device tensor containing the sampled token ids.
-    sampled_token_ids: Optional[torch.Tensor] = None
-    # CPU tensor containing the sampled token ids. Used during multi-step to
-    # return the sampled token ids from last rank to AsyncLLMEngine to be
-    # 'broadcasted' to all other PP ranks for next step.
-    sampled_token_ids_cpu: Optional[torch.Tensor] = None
-
-    # Spec decode metrics populated by workers.
-    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
-
-    # Optional last hidden states from the model.
-    hidden_states: Optional[torch.Tensor] = None
-
-    # Optional prefill hidden states from the model
-    # (used for models like EAGLE).
-    prefill_hidden_states: Optional[torch.Tensor] = None
-
-    # Time taken in the forward pass for this across all workers
-    model_forward_time: Optional[float] = None
-
-    def __getitem__(self, idx: int):
-        return self.outputs[idx]
-
-    def __setitem__(self, idx: int, value):
-        self.outputs[idx] = value
-
-    def __len__(self):
-        return len(self.outputs)
-
-    def __eq__(self, other: object):
-        return isinstance(other,
-                          self.__class__) and self.outputs == other.outputs
-
-    def __repr__(self) -> str:
-        """Show the shape of a tensor instead of its values to reduce noise.
-        """
-        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
-                                    else self.sampled_token_probs.shape)
-        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
-                                  self.sampled_token_ids.shape)
-        return (
-            f"SamplerOutput(outputs={self.outputs}, "
-            f"sampled_token_probs={sampled_token_probs_repr}, "
-            f"sampled_token_ids={sampled_token_ids_repr}, "
-            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
-
-
 class PoolerOutput(
         msgspec.Struct,
         omit_defaults=True,  # type: ignore[call-arg]

+ 4 - 3
aphrodite/engine/aphrodite_engine.py

@@ -24,9 +24,9 @@ from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
-                                       ExecuteModelRequest, SamplerOutput,
-                                       Sequence, SequenceGroup,
-                                       SequenceGroupMetadata, SequenceStatus)
+                                       ExecuteModelRequest, Sequence,
+                                       SequenceGroup, SequenceGroupMetadata,
+                                       SequenceStatus)
 from aphrodite.common.utils import Counter, Device
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics_types import StatLoggerBase, Stats
@@ -42,6 +42,7 @@ from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                               SingletonPromptInputs)
 from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.multimodal import MultiModalDataDict
 from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
                                             SchedulerOutputs)

+ 2 - 1
aphrodite/engine/async_aphrodite.py

@@ -14,7 +14,7 @@ from aphrodite.common.config import (DecodingConfig, EngineConfig, LoRAConfig,
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import print_warning_once
 from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
                                                DecoderPromptComponents,
@@ -29,6 +29,7 @@ from aphrodite.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                               SingletonPromptInputs)
 from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.processing.scheduler import SchedulerOutputs
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.transformers_utils.tokenizer import AnyTokenizer

+ 11 - 3
aphrodite/engine/output_processor/multi_step.py

@@ -11,6 +11,8 @@ from aphrodite.common.sequence import (Sequence, SequenceGroup,
 from aphrodite.common.utils import Counter
 from aphrodite.engine.output_processor.interfaces import (
     SequenceGroupOutputProcessor)
+from aphrodite.engine.output_processor.single_step import (
+    single_step_process_prompt_logprob)
 from aphrodite.engine.output_processor.stop_checker import StopChecker
 from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
@@ -46,9 +48,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):

     def process_prompt_logprob(self, seq_group: SequenceGroup,
                                outputs: List[SequenceGroupOutput]) -> None:
-        # TODO: Prompt logprob currently not implemented in multi step
-        # workers.
-        self._log_prompt_logprob_unsupported_warning_once()
+        """Process prompt logprobs associated with each step of a multi-step-
+        scheduled computation.
+        Args:
+          seq_group: the outputs are associated with this :class:`SequenceGroup`
+          outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
+        """
+        for output in outputs:
+            # Concatenate single-step prompt logprob processing results.
+            single_step_process_prompt_logprob(self, seq_group, output)

     @staticmethod
     @functools.lru_cache()

+ 44 - 17
aphrodite/engine/output_processor/single_step.py

@@ -13,6 +13,42 @@ from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer


+def single_step_process_prompt_logprob(
+        sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
+        output: SequenceGroupOutput) -> None:
+    """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
+    for a given step.
+    Do nothing if the output has no prompt logprobs.
+    Account for the fact that transformers do not compute first-token logprobs.
+
+    Args:
+      sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
+      seq_group: the output is associated with this :class:`SequenceGroup`
+      output: the :class:`SequenceGroupOutput` for a single scheduler step
+    """
+    prompt_logprobs = output.prompt_logprobs
+
+    # If this is the first (or only) "chunk" of the prefill, we need
+    # to prepend None to the list of prompt logprobs. The reason for this
+    # is that for N prompt tokens, the Sampler will generate N-1 total
+    # prompt logprobs during prefill since the token at idx 0 will not
+    # have a logprob associated with it.
+    if prompt_logprobs is not None:
+        if not seq_group.prompt_logprobs:
+            prompt_logprobs = [None] + prompt_logprobs
+            seq_group.prompt_logprobs = []
+
+        assert hasattr(sg_output_proc, 'detokenizer')
+        if (seq_group.sampling_params.detokenize
+                and sg_output_proc.detokenizer):
+            sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
+                seq_group,
+                prompt_logprobs,
+                position_offset=len(seq_group.prompt_logprobs))
+
+        seq_group.prompt_logprobs.extend(prompt_logprobs)
+
+
 class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
     """SequenceGroupOutputProcessor which handles "output processing" logic,
     which happens after the model returns generated token ids and before
@@ -57,25 +93,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):

     def process_prompt_logprob(self, seq_group: SequenceGroup,
                                outputs: List[SequenceGroupOutput]) -> None:
+        """Process prompt logprobs associated with one step of a single-step-
+        scheduled computation.
+        
+        Args:
+          seq_group: the output is associated with this :class:`SequenceGroup`
+          output: the :class:`SequenceGroupOutput` for a single scheduler step
+        """
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
-        prompt_logprobs = output.prompt_logprobs
-
-        # If this is the first (or only) "chunk" of the prefill, we need
-        # to prepend None to the list of prompt logprobs. The reason for this
-        # is that for N prompt tokens, the Sampler will generate N-1 total
-        # prompt logprobs during prefill since the token at idx 0 will not
-        # have a logprob associated with it.
-        if prompt_logprobs is not None:
-            if not seq_group.prompt_logprobs:
-                prompt_logprobs = [None] + prompt_logprobs
-                seq_group.prompt_logprobs = []
-            if seq_group.sampling_params.detokenize and self.detokenizer:
-                self.detokenizer.decode_prompt_logprobs_inplace(
-                    seq_group,
-                    prompt_logprobs,
-                    position_offset=len(seq_group.prompt_logprobs))
-            seq_group.prompt_logprobs.extend(prompt_logprobs)
+        single_step_process_prompt_logprob(self, seq_group, output)

     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                         outputs: SequenceGroupOutput,

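The helper factored out above, single_step_process_prompt_logprob, is what the multi-step processor earlier in this diff now calls once per scheduler step: on the first prefill chunk it prepends None, because the sampler emits N-1 prompt logprobs for N prompt tokens, and then appends each chunk's logprobs to the sequence group. A self-contained sketch of that bookkeeping, with plain dicts standing in for Aphrodite's Logprob objects:

from typing import Dict, List, Optional

PromptLogprobs = List[Optional[Dict[int, float]]]


def accumulate_prompt_logprobs(
        accumulated: Optional[PromptLogprobs],
        chunk: Optional[PromptLogprobs]) -> Optional[PromptLogprobs]:
    """Merge one step's (or prefill chunk's) prompt logprobs into the list."""
    if chunk is None:        # this step produced no prompt logprobs
        return accumulated
    if not accumulated:      # first chunk: prompt token 0 has no logprob
        chunk = [None] + chunk
        accumulated = []
    accumulated.extend(chunk)
    return accumulated


# Two prefill chunks of a 5-token prompt: 2 + 2 logprobs -> 5 entries total.
acc = accumulate_prompt_logprobs(None, [{11: -0.3}, {12: -1.2}])
acc = accumulate_prompt_logprobs(acc, [{13: -0.7}, {14: -2.0}])
assert len(acc) == 5 and acc[0] is None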
+ 2 - 2
aphrodite/engine/output_processor/util.py

@@ -2,8 +2,8 @@ from typing import List
 from typing import Sequence as GenericSequence
 from typing import Union

-from aphrodite.common.sequence import (PoolerOutput, SamplerOutput,
-                                       SequenceGroupOutput)
+from aphrodite.common.sequence import PoolerOutput, SequenceGroupOutput
+from aphrodite.modeling.layers.sampler import SamplerOutput


 def create_output_by_sequence_group(

+ 1 - 1
aphrodite/engine/protocol.py

@@ -6,9 +6,9 @@ from aphrodite.common.config import DecodingConfig, ModelConfig
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.inputs.data import PromptInputs
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.processing.scheduler import SchedulerOutputs
 from aphrodite.prompt_adapter.request import PromptAdapterRequest


+ 2 - 1
aphrodite/executor/cpu_executor.py

@@ -7,7 +7,7 @@ from loguru import logger

 import aphrodite.common.envs as envs
 from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (GiB_bytes, get_aphrodite_instance_id,
                                     get_distributed_init_method, get_open_port,
                                     make_async)
@@ -16,6 +16,7 @@ from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                        ResultHandler,
                                                        WorkerMonitor)
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.task_handler.worker_base import WorkerWrapperBase


+ 2 - 1
aphrodite/executor/distributed_gpu_executor.py

@@ -4,10 +4,11 @@ from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

 from loguru import logger

-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput


 class DistributedGPUExecutor(GPUExecutor):

+ 2 - 1
aphrodite/executor/executor_base.py

@@ -5,8 +5,9 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest



+ 2 - 2
aphrodite/executor/gpu_executor.py

@@ -2,12 +2,12 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

 from loguru import logger

-from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
-                                       SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, PoolerOutput
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.task_handler.worker_base import WorkerBase, WorkerWrapperBase


+ 2 - 1
aphrodite/executor/multiproc_gpu_executor.py

@@ -9,7 +9,7 @@ from typing import Any, List, Optional
 import torch
 from loguru import logger

-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (_run_task_with_lock,
                                     cuda_device_count_stateless,
                                     get_aphrodite_instance_id,
@@ -21,6 +21,7 @@ from aphrodite.executor.gpu_executor import create_worker
 from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                        ResultHandler,
                                                        WorkerMonitor)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.triton_utils import maybe_set_triton_cache_manager



+ 2 - 1
aphrodite/executor/neuron_executor.py

@@ -1,10 +1,11 @@
 from typing import List, Set, Tuple

-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput


 class NeuronExecutor(ExecutorBase):

+ 2 - 1
aphrodite/executor/openvino_executor.py

@@ -7,11 +7,12 @@ from loguru import logger

 import aphrodite.common.envs as envs
 from aphrodite.common.config import CacheConfig, ModelConfig
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (GiB_bytes, get_distributed_init_method,
                                     get_ip, get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput

 APHRODITE_OPENVINO_KVCACHE_SPACE = envs.APHRODITE_OPENVINO_KVCACHE_SPACE
 APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION = (

+ 2 - 1
aphrodite/executor/ray_gpu_executor.py

@@ -8,7 +8,7 @@ import msgspec
 from loguru import logger

 import aphrodite.common.envs as envs
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (_run_task_with_lock,
                                     get_aphrodite_instance_id,
                                     get_distributed_init_method, get_ip,
@@ -17,6 +17,7 @@ from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from aphrodite.executor.msgspec_utils import encode_hook
 from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
+from aphrodite.modeling.layers.sampler import SamplerOutput

 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

+ 2 - 1
aphrodite/executor/ray_tpu_executor.py

@@ -8,13 +8,14 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
 from loguru import logger

 import aphrodite.common.envs as envs
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_aphrodite_instance_id,
                                     get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
 from aphrodite.executor.tpu_executor import TPUExecutor
+from aphrodite.modeling.layers.sampler import SamplerOutput

 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

+ 2 - 1
aphrodite/executor/tpu_executor.py

@@ -3,11 +3,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple
 import torch
 from loguru import logger

-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput


 class TPUExecutor(ExecutorBase):

+ 2 - 2
aphrodite/executor/xpu_executor.py

@@ -7,11 +7,11 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
-from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
-                                       SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, PoolerOutput
 from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.task_handler.worker_base import WorkerBase



+ 247 - 44
aphrodite/modeling/layers/sampler.py

@@ -1,10 +1,12 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
 import warnings
+from dataclasses import dataclass
 from enum import IntEnum
 from math import inf
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

+import msgspec
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -14,7 +16,8 @@ import aphrodite.common.envs as envs
 from aphrodite.common.sampling_params import SamplingType
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        PromptLogprobs, SampleLogprobs,
-                                       SamplerOutput, SequenceOutput)
+                                       SequenceOutput)
+from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 from aphrodite.triton_utils import HAS_TRITON

 if HAS_TRITON:
@@ -27,6 +30,115 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]

+# Types of temporary data structures used for
+# computing sample_result
+SampleMetadataType = Dict[SamplingType, Tuple[List[int],
+                                              List[SequenceGroupToSample]]]
+MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
+SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
+
+
+# Encapsulates temporary data structures for computing
+# sample_result.
+#
+# * For multi-step scheduling: must be returned
+#   by `Sampler.forward()` and used later to compute the pythonized
+#   sample_result
+#
+# * For single-step scheduling: consumed immediately
+#   inside `Sampler.forward()` to compute pythonized sample_result.
+@dataclass
+class SampleResultArgsType:
+    sample_metadata: SampleMetadataType
+    multinomial_samples: MultinomialSamplesType
+    sample_results_dict: SampleResultsDictType
+    sampling_metadata: SamplingMetadata
+    greedy_samples: Optional[torch.Tensor]
+    beam_search_logprobs: Optional[torch.Tensor]
+
+
+# Union of non-deferred (single-step scheduling)
+# vs deferred (multi-step scheduling)
+# sample result types
+MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
+
+# Abbreviation of the _sample() return type
+SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
+
+
+class SamplerOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+    """For each sequence group, we generate a list of SequenceOutput object,
+    each of which contains one possible candidate for the next token.
+    This data structure implements methods, so it can be used like a list, but
+    also has optional fields for device tensors.
+    """
+
+    outputs: List[CompletionSequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional[torch.Tensor] = None
+
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None
+
+    # Holds either (1) the pythonized sampler result (single-step scheduling)
+    # or (2) what will be arguments for later deferred pythonization of the
+    # sampler result (muliti-step scheduling)
+    deferred_sample_results_args: Optional[SampleResultArgsType] = None
+
+    # On-device tensor containing the sampled token ids.
+    sampled_token_ids: Optional[torch.Tensor] = None
+    # CPU tensor containing the sampled token ids. Used during multi-step to
+    # return the sampled token ids from last rank to AsyncLLMEngine to be
+    # 'broadcasted' to all other PP ranks for next step.
+    sampled_token_ids_cpu: Optional[torch.Tensor] = None
+
+    # Spec decode metrics populated by workers.
+    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
+
+    # Optional last hidden states from the model.
+    hidden_states: Optional[torch.Tensor] = None
+
+    # Optional prefill hidden states from the model
+    # (used for models like EAGLE).
+    prefill_hidden_states: Optional[torch.Tensor] = None
+
+    # Time taken in the forward pass for this across all workers
+    model_forward_time: Optional[float] = None
+
+    # Time taken in the model execute function. This will include model forward,
+    # block/sync across workers, cpu-gpu sync time and sampling time.
+    model_execute_time: Optional[float] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+    def __repr__(self) -> str:
+        """Show the shape of a tensor instead of its values to reduce noise.
+        """
+        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
+                                    else self.sampled_token_probs.shape)
+        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
+                                  self.sampled_token_ids.shape)
+        return (
+            f"SamplerOutput(outputs={self.outputs}, "
+            f"sampled_token_probs={sampled_token_probs_repr}, "
+            f"sampled_token_ids={sampled_token_ids_repr}, "
+            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
 # There isn't a "safe" temperature range for fp16 logits.
 # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
 # that this temperature well-uses the fp16 space after the logits are offset.
@@ -135,6 +247,18 @@ class Sampler(nn.Module):
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
         """
+        Single-step scheduling:
+        * Perform GPU-side sampling computation & compute
+          GPU-side logprobs tensor
+        * Pythonize sampling result & logprobs tensor
+        Multi-step scheduling:
+        * Perform GPU-side sampling computation & compute
+          GPU-side logprobs tensor
+        * Defer Pythonization of sampling result & logprobs
+          tensor
+        * Encapsulate arguments required for deferred Pythonization
+          in the :class:`SamplerOutput` structure
+
         Args:
             logits: (num_tokens, vocab_size).
             sampling_metadata: Metadata for sampling.
@@ -425,7 +549,7 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

         # Sample the next tokens.
-        sample_results, maybe_sampled_tokens_tensor = _sample(
+        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
             probs,
             logprobs,
             sampling_metadata,
@@ -435,20 +559,28 @@ class Sampler(nn.Module):
         )

         if self.include_gpu_probs_tensor:
+            # Since we will defer sampler result Pythonization,
+            # preserve GPU-side tensors in support of later
+            # deferred pythonization of logprobs
             assert maybe_sampled_tokens_tensor is not None
             assert maybe_sampled_tokens_tensor is not None
             on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
         else:
+            # GPU-side tensors.
             on_device_tensors = None
             on_device_tensors = None

         # Get the logprobs query results.
         prompt_logprobs = None
         sample_logprobs = None
         if not sampling_metadata.skip_sampler_cpu_output:
-                logprobs, sampling_metadata, sample_results)
+            # Pythonize logprobs now (GPU -> CPU); do not defer.
+            assert not isinstance(maybe_deferred_sample_results,
+                                  SampleResultArgsType)
+            prompt_logprobs, sample_logprobs = get_logprobs(
+                logprobs, sampling_metadata, maybe_deferred_sample_results)

         return _build_sampler_output(
-            sample_results,
+            maybe_deferred_sample_results,
             sampling_metadata,
             prompt_logprobs,
             sample_logprobs,
@@ -1205,6 +1337,57 @@ def _top_k_top_p_multinomial_with_kernels(
     return batch_next_token_ids.view(-1, num_samples)


+def get_pythonized_sample_results(
+        sample_result_args: SampleResultArgsType) -> SampleResultType:
+    """This function consumes GPU-side sampler results and computes
+    Pythonized CPU-side sampler results (GPU -> CPU sync.)
+    Single-step scheduling: this function is invoked at sampling-time
+    for immediate Pythonization.
+    Multi-step scheduling: Pythonization is deferred until after multiple
+    GPU-side steps have been completed.
+
+    Args:
+      sample_result_args: GPU-side inputs to the Pythonization process
+    Returns:
+      Pythonized sampler results
+    """
+
+    (
+        sample_metadata,
+        sampling_metadata,
+        greedy_samples,
+        multinomial_samples,
+        beam_search_logprobs,
+        sample_results_dict,
+    ) = (
+        sample_result_args.sample_metadata,
+        sample_result_args.sampling_metadata,
+        sample_result_args.greedy_samples,
+        sample_result_args.multinomial_samples,
+        sample_result_args.beam_search_logprobs,
+        sample_result_args.sample_results_dict,
+    )
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, greedy_samples)
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
+            sample_results = _random_sample(seq_groups,
+                                            multinomial_samples[sampling_type])
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups,
+                                                 beam_search_logprobs)
+        sample_results_dict.update(zip(seq_group_id, sample_results))
+
+    return [
+        sample_results_dict.get(i, ([], []))
+        for i in range(len(sampling_metadata.seq_groups))
+    ]
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -1212,7 +1395,18 @@ def _sample_with_torch(
     sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+) -> SampleReturnType:
+    """Torch-oriented _sample() implementation.
+
+    Single-step scheduling: 
+    * Perform GPU-side sampling computation
+    * Immediately Pythonize sampling result
+
+    Multi-step scheduling:
+    * Perform GPU-side sampling computation
+    * Defer Pythonization & preserve GPU-side
+      tensors required for Pythonization
+    """
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
@@ -1220,9 +1414,11 @@ def _sample_with_torch(
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)

-    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
-    sample_metadata = {}
-    multinomial_samples = {}
+    sample_results_dict: SampleResultsDictType = {}
+    sample_metadata: SampleMetadataType = {}
+    multinomial_samples: MultinomialSamplesType = {}
+    greedy_samples: Optional[torch.Tensor] = None
+    beam_search_logprobs: Optional[torch.Tensor] = None
     # Create output tensor for sampled token ids.
     if include_gpu_probs_tensor:
         sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
@@ -1293,32 +1489,29 @@ def _sample_with_torch(
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")

-    # GPU<->CPU sync happens in the loop below.
-    # This also converts the sample output to Python objects.
+    # Encapsulate arguments for computing Pythonized sampler
+    # results, whether deferred or otherwise.
+    maybe_deferred_args = SampleResultArgsType(
+        sampling_metadata=sampling_metadata,
+        sample_metadata=sample_metadata,
+        multinomial_samples=multinomial_samples,
+        greedy_samples=greedy_samples,
+        beam_search_logprobs=beam_search_logprobs,
+        sample_results_dict=sample_results_dict)
+
     if not sampling_metadata.skip_sampler_cpu_output:
-        for sampling_type in SamplingType:
-            if sampling_type not in sample_metadata:
-                continue
-            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
-            if sampling_type == SamplingType.GREEDY:
-                sample_results = _greedy_sample(seq_groups, greedy_samples)
-            elif sampling_type in (SamplingType.RANDOM,
-                                   SamplingType.RANDOM_SEED):
-                sample_results = _random_sample(
-                    seq_groups, multinomial_samples[sampling_type])
-            elif sampling_type == SamplingType.BEAM:
-                sample_results = _beam_search_sample(seq_groups,
-                                                     beam_search_logprobs)
-            sample_results_dict.update(zip(seq_group_id, sample_results))
-
-        sample_results = [
-            sample_results_dict.get(i, ([], []))
-            for i in range(len(sampling_metadata.seq_groups))
-        ]
+        # GPU<->CPU sync happens here.
+        # This also converts the sampler output to a Python object.
+        # Return Pythonized sampler result & sampled token ids
+        return get_pythonized_sample_results(
+            maybe_deferred_args), sampled_token_ids_tensor
     else:
-        sample_results = []
-
-    return sample_results, sampled_token_ids_tensor
+        # Defer sampler result Pythonization; return deferred
+        # Pythonization args & sampled token ids
+        return (
+            maybe_deferred_args,
+            sampled_token_ids_tensor,
+        )


 def _sample_with_triton_kernel(
@@ -1396,10 +1589,13 @@


 def _sample(
-    probs: torch.Tensor, logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
+) -> SampleReturnType:
     """
     Args:
         probs: (num_query_tokens_in_batch, num_vocab)
@@ -1441,7 +1637,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     return (x > vals[:, None]).long().sum(1).add_(1)


-def _get_logprobs(
+def get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sample_results: List[Tuple[List[int], List[int]]],
@@ -1755,7 +1951,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,


 def _build_sampler_output(
-    sample_results: SampleResultType,
+    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
     sample_logprobs: Optional[List[SampleLogprobs]],
@@ -1771,14 +1967,21 @@ def _build_sampler_output(
             speculative decoding rejection sampling.
     """
     sampler_output: List[CompletionSequenceGroupOutput] = []
-    if not skip_sampler_cpu_output:
+
+    if skip_sampler_cpu_output:
+        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
+        deferred_sample_results_args = maybe_deferred_sample_results
+    else:
         assert prompt_logprobs is not None
         assert sample_logprobs is not None
+        assert not isinstance(maybe_deferred_sample_results,
+                              SampleResultArgsType)
+        deferred_sample_results_args = None

         for (seq_group, sample_result, group_prompt_logprobs,
              group_sample_logprobs) in zip(sampling_metadata.seq_groups,
-                                           sample_results, prompt_logprobs,
-                                           sample_logprobs):
+                                           maybe_deferred_sample_results,
+                                           prompt_logprobs, sample_logprobs):
             seq_ids = seq_group.seq_ids
             next_token_ids, parent_ids = sample_result
             seq_outputs: List[SequenceOutput] = []
@@ -1802,7 +2005,7 @@ def _build_sampler_output(
         sampled_token_probs=sampled_token_probs,
         sampled_token_ids=sampled_token_ids,
         logprobs=logprobs_tensor,
-    )
+        deferred_sample_results_args=deferred_sample_results_args)


 def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:

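Taken together, the sampler.py changes let SamplerOutput carry either an already-Pythonized result (single-step scheduling) or the deferred_sample_results_args that get_pythonized_sample_results and get_logprobs consume only after the last of several GPU-side steps. A simplified, self-contained sketch of that deferred-Pythonization pattern; the classes below are stand-ins for illustration, not Aphrodite's real types:

from dataclasses import dataclass, field
from typing import List, Optional

import torch


@dataclass
class DeferredSampleArgs:
    """Stand-in for SampleResultArgsType: what is needed to pythonize later."""
    sampled_token_ids: torch.Tensor  # stays on the device in the real code


@dataclass
class StepOutput:
    """Stand-in for SamplerOutput with its deferred_sample_results_args field."""
    deferred_sample_results_args: Optional[DeferredSampleArgs] = None
    pythonized: List[int] = field(default_factory=list)


def sample_one_step(logits: torch.Tensor, defer: bool) -> StepOutput:
    token_ids = torch.argmax(logits, dim=-1)  # GPU-side sampling result
    if defer:
        # Multi-step path: skip the GPU->CPU sync, remember how to do it later.
        return StepOutput(
            deferred_sample_results_args=DeferredSampleArgs(token_ids))
    # Single-step path: pythonize immediately.
    return StepOutput(pythonized=token_ids.tolist())


def pythonize(out: StepOutput) -> StepOutput:
    args = out.deferred_sample_results_args
    if args is not None:  # deferred: sync now, once, after all steps ran
        out.pythonized = args.sampled_token_ids.tolist()
    return out


# Run several decode steps back to back, then pythonize them all at the end.
steps = [sample_one_step(torch.randn(2, 32), defer=True) for _ in range(4)]
print([pythonize(s).pythonized for s in steps])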
+ 1 - 2
aphrodite/modeling/model_loader/neuron.py

@@ -10,9 +10,8 @@ from transformers import PretrainedConfig

 from aphrodite.common.config import (ModelConfig, ParallelConfig,
                                      SchedulerConfig)
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 
 
 TORCH_DTYPE_TO_NEURON_AMP = {
 TORCH_DTYPE_TO_NEURON_AMP = {
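
This file and every one below apply the same mechanical refactor: since `SamplerOutput` now carries the deferred sample results, it moves next to the sampler, and each model or loader swaps its import from `aphrodite.common.sequence` to `aphrodite.modeling.layers.sampler`. The two shapes of the change, taken directly from the hunks (they assume an Aphrodite environment):

```python
# Models and loaders that instantiate a Sampler collapse to one import:
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput

# Wrappers that only annotate return types import the output type alone:
from aphrodite.modeling.layers.sampler import SamplerOutput
```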

+ 1 - 2
aphrodite/modeling/model_loader/openvino.py

@@ -13,10 +13,9 @@ from torch import nn
 import aphrodite.common.envs as envs
 from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
 from aphrodite.common.config import DeviceConfig, ModelConfig
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import (LogitsProcessor,
                                                         _prune_hidden_states)
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 
 APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS = (

+ 2 - 2
aphrodite/modeling/models/arctic.py

@@ -7,7 +7,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -20,7 +20,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/baichuan.py

@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/bart.py

@@ -25,14 +25,14 @@ from transformers import BartConfig
 
 from aphrodite.attention import Attention, AttentionMetadata, AttentionType
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 3
aphrodite/modeling/models/blip2.py

@@ -9,13 +9,12 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.opt import OPTModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata

+ 2 - 2
aphrodite/modeling/models/bloom.py

@@ -25,7 +25,7 @@ from transformers import BloomConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 3
aphrodite/modeling/models/chameleon.py

@@ -11,8 +11,7 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.common.utils import print_warning_once
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.distributed import get_tensor_model_parallel_world_size
@@ -24,7 +23,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/chatglm.py

@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -19,7 +19,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/commandr.py

@@ -29,7 +29,7 @@ from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/dbrx.py

@@ -6,7 +6,7 @@ import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -16,7 +16,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/deepseek.py

@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/deepseek_v2.py

@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/modeling/models/eagle.py

@@ -4,8 +4,9 @@ import torch
 import torch.nn as nn
 
 from aphrodite.attention.backends.abstract import AttentionMetadata
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/exaone.py

@@ -30,7 +30,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/falcon.py

@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -38,7 +38,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/fuyu.py

@@ -27,11 +27,11 @@ from transformers import FuyuConfig, FuyuImageProcessor
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ColumnParallelLinear
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
 from aphrodite.modeling.sampling_metadata import SamplingMetadata

+ 2 - 2
aphrodite/modeling/models/gemma.py

@@ -24,7 +24,7 @@ from transformers import GemmaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gemma2.py

@@ -24,7 +24,7 @@ from transformers import Gemma2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt2.py

@@ -25,14 +25,14 @@ from transformers import GPT2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt_bigcode.py

@@ -26,14 +26,14 @@ from transformers import GPTBigCodeConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt_j.py

@@ -24,7 +24,7 @@ from transformers import GPTJConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -32,7 +32,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt_neox.py

@@ -24,7 +24,7 @@ from transformers import GPTNeoXConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -32,7 +32,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/internlm2.py

@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -16,7 +16,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/modeling/models/internvl.py

@@ -16,8 +16,9 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.intern_vit import InternVisionModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata

+ 2 - 2
aphrodite/modeling/models/jais.py

@@ -27,14 +27,14 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/jamba.py

@@ -11,7 +11,7 @@ from transformers import JambaConfig
 from aphrodite.attention.backends.abstract import AttentionMetadata
 from aphrodite.attention.layer import Attention
 from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 # yapf: disable
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -28,7 +28,7 @@ from aphrodite.modeling.layers.mamba import (causal_conv1d_fn,
                                              causal_conv1d_update,
                                              selective_scan_fn,
                                              selective_state_update)
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/llama.py

@@ -29,7 +29,7 @@ from transformers import LlamaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 1
aphrodite/modeling/models/llava.py

@@ -8,9 +8,10 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY

+ 2 - 1
aphrodite/modeling/models/llava_next.py

@@ -12,9 +12,10 @@ from typing_extensions import NotRequired
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_list_of
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY

+ 2 - 2
aphrodite/modeling/models/mamba.py

@@ -10,7 +10,7 @@ from transformers import MambaConfig
 
 from aphrodite.attention.backends.abstract import AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -23,7 +23,7 @@ from aphrodite.modeling.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from aphrodite.modeling.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 1 - 1
aphrodite/modeling/models/medusa.py

@@ -3,8 +3,8 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/minicpm.py

@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -44,7 +44,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 3
aphrodite/modeling/models/minicpmv.py

@@ -39,13 +39,12 @@ from transformers.configuration_utils import PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ReplicatedLinear
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/mixtral.py

@@ -29,7 +29,7 @@ from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.fused_moe import FusedMoE
@@ -39,7 +39,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/mixtral_quant.py

@@ -31,7 +31,7 @@ from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -41,7 +41,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 1 - 2
aphrodite/modeling/models/mlp_speculator.py

@@ -4,9 +4,8 @@ from typing import Iterable, List, Tuple
 import torch
 import torch.nn as nn
 
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/mpt.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -16,7 +16,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/nemotron.py

@@ -30,7 +30,7 @@ from transformers import NemotronConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -39,7 +39,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/olmo.py

@@ -29,7 +29,7 @@ from transformers import OlmoConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/olmoe.py

@@ -18,7 +18,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -27,7 +27,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/opt.py

@@ -25,7 +25,7 @@ from transformers import OPTConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/orion.py

@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -20,7 +20,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/paligemma.py

@@ -8,10 +8,10 @@ from transformers import PaliGemmaConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.gemma import GemmaModel
 from aphrodite.modeling.models.gemma2 import Gemma2Model

+ 2 - 2
aphrodite/modeling/models/persimmon.py

@@ -30,14 +30,14 @@ from transformers.activations import ReLUSquaredActivation
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/phi.py

@@ -43,7 +43,7 @@ from transformers import PhiConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -51,7 +51,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/phi3_small.py

@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -15,7 +15,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/phi3v.py

@@ -29,11 +29,11 @@ from transformers import CLIPVisionConfig, PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_list_of
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.clip import CLIPVisionModel

+ 2 - 2
aphrodite/modeling/models/qwen.py

@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -21,7 +21,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/qwen2.py

@@ -30,7 +30,7 @@ from transformers import Qwen2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/qwen2_moe.py

@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size,
@@ -45,7 +45,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/solar.py

@@ -29,7 +29,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -41,7 +41,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/stablelm.py

@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -35,7 +35,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/starcoder2.py

@@ -26,7 +26,7 @@ from transformers import Starcoder2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -34,7 +34,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/modeling/models/ultravox.py

@@ -20,12 +20,13 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SamplerOutput, SequenceData)
+                                       SequenceData)
 from aphrodite.inputs import INPUT_REGISTRY
 from aphrodite.inputs.data import LLMInputs
 from aphrodite.inputs.registry import InputContext
 from aphrodite.modeling.layers.activation import SiluAndMul, get_act_fn
 from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.interfaces import SupportsMultiModal
 from aphrodite.modeling.models.utils import (filter_weights,

+ 2 - 2
aphrodite/modeling/models/xverse.py

@@ -28,7 +28,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/quantization/gguf.py

@@ -125,7 +125,8 @@ class GGUFLinearMethod(LinearMethodBase):
                     offset:offset +
                     shard_size[id][0], :shard_size[id][1]].contiguous()
                 qweight_type = layer.qweight_type.shard_weight_type[id]
-                result.append(_fuse_mul_mat(x, shard_weight, qweight_type).contiguous())
+                result.append(_fuse_mul_mat(x, shard_weight,
+                                            qweight_type).contiguous())
                 offset += shard_size[id][0]
             out = torch.cat(result, dim=-1)
         else:

+ 3 - 3
aphrodite/spec_decode/batch_expansion.py

@@ -6,9 +6,9 @@ import torch
 
 from aphrodite import SamplingParams
 from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       ExecuteModelRequest, SamplerOutput,
-                                       SequenceData, SequenceGroupMetadata,
-                                       get_all_seq_ids)
+                                       ExecuteModelRequest, SequenceData,
+                                       SequenceGroupMetadata, get_all_seq_ids)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeScorer,
                                               SpeculativeScores)

+ 2 - 2
aphrodite/spec_decode/draft_model_runner.py

@@ -15,8 +15,8 @@ except ModuleNotFoundError:
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
-from aphrodite.common.sequence import (ExecuteModelRequest,
-                                       IntermediateTensors, SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.multimodal import MultiModalInputs
 from aphrodite.task_handler.model_runner import (
     ModelInputForGPUWithSamplingMetadata, ModelRunner)

+ 2 - 1
aphrodite/spec_decode/medusa_worker.py

@@ -3,8 +3,9 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+from aphrodite.common.sequence import (ExecuteModelRequest,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase

+ 2 - 1
aphrodite/spec_decode/mlp_speculator_worker.py

@@ -2,8 +2,9 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+from aphrodite.common.sequence import (ExecuteModelRequest,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase

+ 2 - 2
aphrodite/spec_decode/multi_step_worker.py

@@ -5,8 +5,8 @@ from typing import Dict, List, Set, Tuple
 import torch
 
 from aphrodite.common.sequence import (ExecuteModelRequest, HiddenStates,
-                                       SamplerOutput, SequenceData,
-                                       SequenceGroupMetadata)
+                                       SequenceData, SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
 from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeProposer)

+ 2 - 1
aphrodite/spec_decode/ngram_worker.py

@@ -3,7 +3,8 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from aphrodite.spec_decode.top1_proposer import Top1Proposer

+ 2 - 1
aphrodite/spec_decode/proposer_worker_base.py

@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposer
 from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
 

+ 2 - 1
aphrodite/spec_decode/smaller_tp_proposer_worker.py

@@ -3,10 +3,11 @@ from typing import List, Optional, Set, Tuple
 import torch
 from loguru import logger
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed.parallel_state import (get_tp_group,
                                                   init_model_parallel_group,
                                                   patch_tensor_parallel_group)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
 from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase

+ 2 - 2
aphrodite/spec_decode/spec_decode_worker.py

@@ -8,11 +8,11 @@ from loguru import logger
 from aphrodite.common.config import ParallelConfig, SpeculativeConfig
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        ExecuteModelRequest, HiddenStates,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       get_all_seq_ids,
+                                       SequenceGroupMetadata, get_all_seq_ids,
                                        get_all_seq_ids_and_request_ids)
 from aphrodite.distributed.communication_op import broadcast_tensor_dict
 from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from aphrodite.modeling.layers.typical_acceptance_sampler import (

+ 2 - 1
aphrodite/spec_decode/top1_proposer.py

@@ -2,8 +2,9 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+from aphrodite.common.sequence import (ExecuteModelRequest,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeProposer)
 from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase

+ 2 - 2
aphrodite/spec_decode/util.py

@@ -5,8 +5,8 @@ from typing import Dict, List, Optional, Sequence, Tuple
 import torch
 
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       SequenceOutput)
+                                       SequenceGroupMetadata, SequenceOutput)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 SeqId = int
 

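spec_decode/util.py keeps assembling its pythonized outputs from the same sequence-level pieces; only the SamplerOutput import moves. A rough sketch of how those pieces fit together is given below; the constructor arguments are assumptions based on typical usage in this codebase family and are not taken from the diff.

from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       SequenceOutput)
from aphrodite.modeling.layers.sampler import SamplerOutput


def build_single_token_output(seq_id: int, token_id: int,
                              token_logprob: float) -> SamplerOutput:
    # One sampled token for one sequence group, with its logprob attached.
    sample = SequenceOutput(seq_id, token_id,
                            {token_id: Logprob(logprob=token_logprob)})
    group = CompletionSequenceGroupOutput(samples=[sample],
                                          prompt_logprobs=None)
    # A SamplerOutput carries one CompletionSequenceGroupOutput per
    # sequence group in the batch.
    return SamplerOutput(outputs=[group])
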
+ 2 - 1
aphrodite/task_handler/cpu_model_runner.py

@@ -8,9 +8,10 @@ from aphrodite.attention import AttentionMetadata, get_attn_backend
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import make_tensor_with_pad
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,

+ 2 - 1
aphrodite/task_handler/enc_dec_model_runner.py

@@ -17,10 +17,11 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
-                                       SamplerOutput, SequenceGroupMetadata)
+                                       SequenceGroupMetadata)
 from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
                                     make_tensor_with_pad)
 from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from aphrodite.task_handler.model_runner import (

+ 2 - 1
aphrodite/task_handler/model_runner.py

@@ -23,7 +23,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
                                     async_tensor_h2d, flatten_2d_lists, is_hip,
@@ -36,6 +36,7 @@ from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
 from aphrodite.lora.layers import LoRAMapping
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.modeling.models.interfaces import (supports_lora,

+ 2 - 1
aphrodite/task_handler/model_runner_base.py

@@ -5,8 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
 
 import torch
 
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.platforms import current_platform
 
 if TYPE_CHECKING:

+ 147 - 29
aphrodite/task_handler/multi_step_model_runner.py

@@ -1,7 +1,8 @@
 import dataclasses
 import functools
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Union)
 
 try:
     from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
@@ -16,9 +17,12 @@ import torch
 from aphrodite import _custom_ops as ops
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        IntermediateTensors, Logprob,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       SequenceOutput)
+                                       SequenceGroupMetadata, SequenceOutput)
 from aphrodite.distributed import get_pp_group
+from aphrodite.modeling.layers.sampler import (PromptLogprobs, SampleLogprobs,
+                                               SamplerOutput, SamplingMetadata,
+                                               get_logprobs,
+                                               get_pythonized_sample_results)
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.task_handler.model_runner import (
     GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata)
@@ -50,6 +54,8 @@ class ModelOutput:
     sampler_output_ready_event: torch.cuda.Event
     sampled_token_ids: Optional[torch.Tensor] = None
     pythonized: bool = False
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None
 
     def pythonize(
         self,
@@ -85,7 +91,9 @@ class ModelOutput:
     ) -> bool:
         """
         If blocking is set, will block until the forward pass for the output is
-        ready and pythonize the output.
+        ready and pythonize the output. Upon completing Pythonization, erases
+        self.logprobs (note that a non-blocking call performed while the
+        sampler output is not yet ready will not erase self.logprobs).
         """
         """
         assert self.sampled_token_ids is not None
         assert self.sampled_token_ids is not None
         if not blocking and not self.sampler_output_ready_event.query():
         if not blocking and not self.sampler_output_ready_event.query():
@@ -97,8 +105,15 @@ class ModelOutput:
                 input_metadata,
                 input_metadata,
                 self.sampler_output,
                 self.sampler_output,
                 pinned_sampled_token_buffer,
                 pinned_sampled_token_buffer,
-                self.sampled_token_ids,
-            )
+                self.sampled_token_ids, self.logprobs)
+
+        # Erase the logprobs GPU-side tensor.
+        # Note that although _pythonize_sampler_output() runs in its
+        # own CUDA stream, nonetheless _pythonize_sampler_output()
+        # cannot return until Pythonization is complete; therefore
+        # we know that by the time the CPU reaches this point,
+        # `self.logprobs` is no longer needed.
+        self.logprobs = None
         return True
 
 
@@ -355,11 +370,12 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                 ModelOutput(
                     output[0],
                     output_ready_event,
-                    output[0].sampled_token_ids,
-                    False,
-                )
-            )
-            # make sure we dont try to serialize any GPU tensors
+                    output[0].sampled_token_ids, False,
+                            output[0].logprobs))
+
+            # These GPU tensors are not required by multi-step;
+            # erase them to ensure they are not pythonized or
+            # transferred to CPU
             output[0].sampled_token_ids = None
             output[0].sampled_token_probs = None
             output[0].logprobs = None
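
The hunks above follow a poll-then-finalize contract: non-blocking probes of the CUDA event while later steps are still enqueued, a blocking finalize at the end, and the on-device logprobs tensor released only once pythonization has completed. A generic sketch of that pattern (class and method names below are illustrative, not the runner's actual API):

from typing import Optional

import torch


class PendingOutput:
    # Illustrative stand-in for a cached per-step output.
    def __init__(self) -> None:
        self.ready_event = torch.cuda.Event()
        self.logprobs: Optional[torch.Tensor] = None  # stays on GPU until finalized

    def maybe_finalize(self, blocking: bool = False) -> bool:
        # Non-blocking probe: if the producing stream has not reached the
        # recorded event yet, keep the GPU-side logprobs for a later attempt.
        if not blocking and not self.ready_event.query():
            return False
        # ... pythonize the sampler output / copy results to CPU here ...
        # Once pythonization has completed, the GPU tensor is no longer
        # needed and can be dropped.
        self.logprobs = None
        return True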
@@ -464,14 +480,72 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         return self._base_model_runner.vocab_size
 
 
+DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
+                                   Optional[List[SampleLogprobs]]]
+
+
+def deferred_pythonize_logprobs(
+    output: SamplerOutput,
+    sampling_metadata: SamplingMetadata,
+    logprobs_tensor: Optional[torch.Tensor],
+) -> DeferredLogprobsReturnType:
+    """Perform deferred logprob Pythonization.
+    1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
+    2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
+       utilizing the Pythonized sampler result computed in step 1.
+
+    These deferred computations are not required for single-step scheduling
+    or the `profile_run()` phase of multi-step scheduling.
+    Args:
+        output: sampler output (under deferred Pythonization)
+        sampling_metadata
+        logprobs_tensor: GPU-side tensor of logprobs computed during sampling
+    Returns:
+        prompt_logprobs (CPU), sample_logprobs (CPU)
+    """
+
+    # - Deferred pythonization of sample result
+    sampler_result = get_pythonized_sample_results(
+        output.deferred_sample_results_args)
+
+    # - Erase the GPU-side deferred sample_result
+    #   computation args to ensure it is never
+    #   pythonized or transferred to CPU
+    output.deferred_sample_results_args = None
+
+    # - Deferred pythonization of logprobs
+    (
+        prompt_logprobs,
+        sample_logprobs,
+    ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
+    assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
+    assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
+
+    return prompt_logprobs, sample_logprobs
+
+
 def _pythonize_sampler_output(
     model_input: StatefulModelInput,
     output: SamplerOutput,
     pinned_sampled_token_buffer: torch.Tensor,
     sampled_token_ids: torch.Tensor,
+    logprobs_tensor: Optional[torch.Tensor],
 ) -> None:
-    """This function is only called when the output tensors are ready.
-    See ModelOutput
+    """ This function is only called when the output tensors are ready. 
+    See :class:`ModelOutput`. 
+    
+    Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, 
+    adding a Pythonized output data structure
+    (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
+    Args:
+      model_input
+      output: sampler output
+      pinned_sampled_token_token_buffer: CPU-side pinned memory
+                                         (receives copy of
+                                         GPU-side token buffer.)
+      sampled_token_ids: GPU-side token buffer
+      logprobs_tensor: GPU-side tensor containing 
+                       logprobs computed during sampling
     """
     """
     assert model_input.frozen_model_input is not None
     assert model_input.frozen_model_input is not None
     frozen_model_input = model_input.frozen_model_input
     frozen_model_input = model_input.frozen_model_input
@@ -484,26 +558,70 @@ def _pythonize_sampler_output(
     # this will not block as the tensors are already on CPU
     # this will not block as the tensors are already on CPU
     samples_list = pinned_buffer.tolist()
     samples_list = pinned_buffer.tolist()
     sampling_metadata = frozen_model_input.sampling_metadata
     sampling_metadata = frozen_model_input.sampling_metadata
-    for seq_group, sample_result in zip(
-        sampling_metadata.seq_groups, samples_list
-    ):
+    skip_sampler_cpu_output = (
+        frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
+
+    # We are guaranteed output tensors are ready, so it is safe to
+    # pythonize the sampler output & obtain CPU-side logprobs.
+    #
+    # However this computation may be skipped entirely
+    # if no pythonization was deferred.
+    seq_groups = sampling_metadata.seq_groups
+    logprobs_are_requested = any([
+        sg.sampling_params.logprobs is not None
+        or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
+    ])
+    do_pythonize_logprobs = (skip_sampler_cpu_output
+                             and logprobs_are_requested)
+    (
+        prompt_logprobs,
+        sample_logprobs,
+    ) = (deferred_pythonize_logprobs(output, sampling_metadata,
+                                     logprobs_tensor)
+         if do_pythonize_logprobs else (None, None))
+
+    for sgdx, (seq_group,
+               sample_result) in enumerate(zip(seq_groups, samples_list)):
+
+        if do_pythonize_logprobs:
+            assert prompt_logprobs is not None
+            assert sample_logprobs is not None
+
+            (
+                group_prompt_logprobs,
+                group_sample_logprobs,
+            ) = (  # Utilize deferred pythonization results
+                prompt_logprobs[sgdx],
+                sample_logprobs[sgdx],
+            )
+        elif logprobs_are_requested:
+            (
+                group_prompt_logprobs,
+                group_sample_logprobs,
+            ) = (
+                # profile_run: use already-computed logprobs
+                output.outputs[sgdx].prompt_logprobs,
+                [sample.logprobs for sample in output.outputs[sgdx].samples])
         seq_ids = seq_group.seq_ids
         next_token_ids = sample_result
         parent_ids = [0]
         seq_outputs: List[SequenceOutput] = []
         if seq_group.sampling_params.logits_processors:
-            assert (
-                len(seq_group.sampling_params.logits_processors) == 0
-            ), "Logits Processors are not supported in multi-step decoding"
-        for parent_id, next_token_id in zip(parent_ids, next_token_ids):
-            # TODO(will): support logprobs
-            # Hard coded logprob
+            assert len(seq_group.sampling_params.logits_processors) == 0, (
+                "Logits Processors are not supported in multi-step decoding")
+        for tdx, (parent_id,
+                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
             seq_outputs.append(
-                SequenceOutput(
-                    seq_ids[parent_id],
-                    next_token_id,
-                    {next_token_id: Logprob(logprob=-1)},
-                )
-            )
-        output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
+                SequenceOutput(seq_ids[parent_id], next_token_id,
+                               (group_sample_logprobs[tdx]
+                                if logprobs_are_requested else {
+                                    next_token_id:
+                                    Logprob(logprob=float('inf'),
+                                            rank=None,
+                                            decoded_token=None)
+                                })))
+        output.outputs.append(
+            CompletionSequenceGroupOutput(
+                seq_outputs,
+                (group_prompt_logprobs if logprobs_are_requested else None)))
     assert len(output.outputs) > 0
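
The deferred path above only materializes CPU-side logprobs when a request actually asks for them; otherwise the placeholder Logprob(float('inf')) entries are emitted and the GPU tensors are dropped. A minimal usage sketch that would exercise it (assuming the offline LLM entrypoint and a num_scheduler_steps engine option for multi-step scheduling; the model name and values are illustrative):

from aphrodite import LLM, SamplingParams

# num_scheduler_steps > 1 is assumed here to be the switch for multi-step
# scheduling; adjust to whatever the engine actually exposes.
llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8)

# Requesting logprobs and/or prompt_logprobs is what routes a request
# through the deferred pythonization code path.
params = SamplingParams(max_tokens=16, logprobs=5, prompt_logprobs=1)

for request_output in llm.generate(["Multi-step scheduling lets the GPU"],
                                   params):
    # Each generated token carries a dict of token id -> Logprob.
    print(request_output.outputs[0].logprobs)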

+ 2 - 1
aphrodite/task_handler/multi_step_worker.py

@@ -4,8 +4,9 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.task_handler.model_runner_base import BroadcastableModelInput
 from aphrodite.task_handler.multi_step_model_runner import (
     MultiStepModelRunner, StatefulModelInput)

+ 2 - 1
aphrodite/task_handler/neuron_model_runner.py

@@ -7,10 +7,11 @@ from torch import nn
 
 from aphrodite.common.config import (DeviceConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig)
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import (is_pin_memory_available,
                                     make_tensor_with_pad)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.neuron import get_neuron_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,

+ 2 - 1
aphrodite/task_handler/openvino_model_runner.py

@@ -9,7 +9,8 @@ from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import SequenceGroupMetadata
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.openvino import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,

+ 2 - 1
aphrodite/task_handler/openvino_worker.py

@@ -9,11 +9,12 @@ from aphrodite.attention import get_attn_backend
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import (broadcast_tensor_dict,
                                    ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.task_handler.openvino_model_runner import OpenVINOModelRunner
 from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
 

+ 2 - 2
aphrodite/task_handler/tpu_model_runner.py

@@ -16,10 +16,10 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      SchedulerConfig)
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        IntermediateTensors, Logprob,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       SequenceOutput)
+                                       SequenceGroupMetadata, SequenceOutput)
 from aphrodite.compilation.wrapper import (
     TorchCompileWrapperWithCustomDispacther)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.task_handler.model_runner_base import (

+ 2 - 1
aphrodite/task_handler/worker.py

@@ -13,7 +13,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
 from aphrodite.common.sequence import (ExecuteModelRequest,
-                                       IntermediateTensors, SamplerOutput,
+                                       IntermediateTensors,
                                        SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta)
 from aphrodite.distributed import (ensure_model_parallel_initialized,
@@ -22,6 +22,7 @@ from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    set_custom_all_reduce)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling import set_random_seed
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.platforms import current_platform
 from aphrodite.prompt_adapter.request import PromptAdapterRequest

+ 2 - 2
aphrodite/task_handler/worker_base.py

@@ -7,13 +7,13 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 import torch
 from loguru import logger
 
-from aphrodite.common.sequence import (ExecuteModelRequest,
-                                       IntermediateTensors, SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
 from aphrodite.common.utils import (enable_trace_function_call_for_thread,
                                     update_environment_variables)
 from aphrodite.distributed import (broadcast_tensor_dict, get_pp_group,
                                    get_tp_group)
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.platforms import current_platform
 from aphrodite.task_handler.model_runner_base import (BroadcastableModelInput,
                                                       ModelRunnerBase,

+ 2 - 1
aphrodite/task_handler/xpu_model_runner.py

@@ -13,11 +13,12 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
 from aphrodite.distributed import get_pp_group
 from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs, MultiModalRegistry)

Some files were not shown because too many files were changed in this diff