Browse Source

core: support logprobs with multi-step scheduling (#963)

* deferred sampler results

* fix imports and implement within multistep worker

* update tests

* fix test

* fix sequence test

* fix unrelated gguf ruff issue
AlpinDale 2 months ago
parent
commit
0dfa6b60ec
100 changed files with 637 additions and 320 deletions
  1. 0 66
      aphrodite/common/sequence.py
  2. 4 3
      aphrodite/engine/aphrodite_engine.py
  3. 2 1
      aphrodite/engine/async_aphrodite.py
  4. 11 3
      aphrodite/engine/output_processor/multi_step.py
  5. 44 17
      aphrodite/engine/output_processor/single_step.py
  6. 2 2
      aphrodite/engine/output_processor/util.py
  7. 1 1
      aphrodite/engine/protocol.py
  8. 2 1
      aphrodite/executor/cpu_executor.py
  9. 2 1
      aphrodite/executor/distributed_gpu_executor.py
  10. 2 1
      aphrodite/executor/executor_base.py
  11. 2 2
      aphrodite/executor/gpu_executor.py
  12. 2 1
      aphrodite/executor/multiproc_gpu_executor.py
  13. 2 1
      aphrodite/executor/neuron_executor.py
  14. 2 1
      aphrodite/executor/openvino_executor.py
  15. 2 1
      aphrodite/executor/ray_gpu_executor.py
  16. 2 1
      aphrodite/executor/ray_tpu_executor.py
  17. 2 1
      aphrodite/executor/tpu_executor.py
  18. 2 2
      aphrodite/executor/xpu_executor.py
  19. 247 44
      aphrodite/modeling/layers/sampler.py
  20. 1 2
      aphrodite/modeling/model_loader/neuron.py
  21. 1 2
      aphrodite/modeling/model_loader/openvino.py
  22. 2 2
      aphrodite/modeling/models/arctic.py
  23. 2 2
      aphrodite/modeling/models/baichuan.py
  24. 2 2
      aphrodite/modeling/models/bart.py
  25. 2 3
      aphrodite/modeling/models/blip2.py
  26. 2 2
      aphrodite/modeling/models/bloom.py
  27. 2 3
      aphrodite/modeling/models/chameleon.py
  28. 2 2
      aphrodite/modeling/models/chatglm.py
  29. 2 2
      aphrodite/modeling/models/commandr.py
  30. 2 2
      aphrodite/modeling/models/dbrx.py
  31. 2 2
      aphrodite/modeling/models/deepseek.py
  32. 2 2
      aphrodite/modeling/models/deepseek_v2.py
  33. 2 1
      aphrodite/modeling/models/eagle.py
  34. 2 2
      aphrodite/modeling/models/exaone.py
  35. 2 2
      aphrodite/modeling/models/falcon.py
  36. 2 2
      aphrodite/modeling/models/fuyu.py
  37. 2 2
      aphrodite/modeling/models/gemma.py
  38. 2 2
      aphrodite/modeling/models/gemma2.py
  39. 2 2
      aphrodite/modeling/models/gpt2.py
  40. 2 2
      aphrodite/modeling/models/gpt_bigcode.py
  41. 2 2
      aphrodite/modeling/models/gpt_j.py
  42. 2 2
      aphrodite/modeling/models/gpt_neox.py
  43. 2 2
      aphrodite/modeling/models/internlm2.py
  44. 2 1
      aphrodite/modeling/models/internvl.py
  45. 2 2
      aphrodite/modeling/models/jais.py
  46. 2 2
      aphrodite/modeling/models/jamba.py
  47. 2 2
      aphrodite/modeling/models/llama.py
  48. 2 1
      aphrodite/modeling/models/llava.py
  49. 2 1
      aphrodite/modeling/models/llava_next.py
  50. 2 2
      aphrodite/modeling/models/mamba.py
  51. 1 1
      aphrodite/modeling/models/medusa.py
  52. 2 2
      aphrodite/modeling/models/minicpm.py
  53. 2 3
      aphrodite/modeling/models/minicpmv.py
  54. 2 2
      aphrodite/modeling/models/mixtral.py
  55. 2 2
      aphrodite/modeling/models/mixtral_quant.py
  56. 1 2
      aphrodite/modeling/models/mlp_speculator.py
  57. 2 2
      aphrodite/modeling/models/mpt.py
  58. 2 2
      aphrodite/modeling/models/nemotron.py
  59. 2 2
      aphrodite/modeling/models/olmo.py
  60. 2 2
      aphrodite/modeling/models/olmoe.py
  61. 2 2
      aphrodite/modeling/models/opt.py
  62. 2 2
      aphrodite/modeling/models/orion.py
  63. 2 2
      aphrodite/modeling/models/paligemma.py
  64. 2 2
      aphrodite/modeling/models/persimmon.py
  65. 2 2
      aphrodite/modeling/models/phi.py
  66. 2 2
      aphrodite/modeling/models/phi3_small.py
  67. 2 2
      aphrodite/modeling/models/phi3v.py
  68. 2 2
      aphrodite/modeling/models/qwen.py
  69. 2 2
      aphrodite/modeling/models/qwen2.py
  70. 2 2
      aphrodite/modeling/models/qwen2_moe.py
  71. 2 2
      aphrodite/modeling/models/solar.py
  72. 2 2
      aphrodite/modeling/models/stablelm.py
  73. 2 2
      aphrodite/modeling/models/starcoder2.py
  74. 2 1
      aphrodite/modeling/models/ultravox.py
  75. 2 2
      aphrodite/modeling/models/xverse.py
  76. 2 1
      aphrodite/quantization/gguf.py
  77. 3 3
      aphrodite/spec_decode/batch_expansion.py
  78. 2 2
      aphrodite/spec_decode/draft_model_runner.py
  79. 2 1
      aphrodite/spec_decode/medusa_worker.py
  80. 2 1
      aphrodite/spec_decode/mlp_speculator_worker.py
  81. 2 2
      aphrodite/spec_decode/multi_step_worker.py
  82. 2 1
      aphrodite/spec_decode/ngram_worker.py
  83. 2 1
      aphrodite/spec_decode/proposer_worker_base.py
  84. 2 1
      aphrodite/spec_decode/smaller_tp_proposer_worker.py
  85. 2 2
      aphrodite/spec_decode/spec_decode_worker.py
  86. 2 1
      aphrodite/spec_decode/top1_proposer.py
  87. 2 2
      aphrodite/spec_decode/util.py
  88. 2 1
      aphrodite/task_handler/cpu_model_runner.py
  89. 2 1
      aphrodite/task_handler/enc_dec_model_runner.py
  90. 2 1
      aphrodite/task_handler/model_runner.py
  91. 2 1
      aphrodite/task_handler/model_runner_base.py
  92. 147 29
      aphrodite/task_handler/multi_step_model_runner.py
  93. 2 1
      aphrodite/task_handler/multi_step_worker.py
  94. 2 1
      aphrodite/task_handler/neuron_model_runner.py
  95. 2 1
      aphrodite/task_handler/openvino_model_runner.py
  96. 2 1
      aphrodite/task_handler/openvino_worker.py
  97. 2 2
      aphrodite/task_handler/tpu_model_runner.py
  98. 2 1
      aphrodite/task_handler/worker.py
  99. 2 2
      aphrodite/task_handler/worker_base.py
  100. 2 1
      aphrodite/task_handler/xpu_model_runner.py

+ 0 - 66
aphrodite/common/sequence.py

@@ -1046,72 +1046,6 @@ class IntermediateTensors(
         return f"IntermediateTensors(tensors={self.tensors})"
 
 
-class SamplerOutput(
-        msgspec.Struct,
-        omit_defaults=True,  # type: ignore[call-arg]
-        array_like=True):  # type: ignore[call-arg]
-    """For each sequence group, we generate a list of SequenceOutput object,
-    each of which contains one possible candidate for the next token.
-
-    This data structure implements methods, so it can be used like a list, but
-    also has optional fields for device tensors.
-    """
-
-    outputs: List[CompletionSequenceGroupOutput]
-
-    # On-device tensor containing probabilities of each token.
-    sampled_token_probs: Optional[torch.Tensor] = None
-
-    # On-device tensor containing the logprobs of each token.
-    logprobs: Optional["torch.Tensor"] = None
-
-    # On-device tensor containing the sampled token ids.
-    sampled_token_ids: Optional[torch.Tensor] = None
-    # CPU tensor containing the sampled token ids. Used during multi-step to
-    # return the sampled token ids from last rank to AsyncLLMEngine to be
-    # 'broadcasted' to all other PP ranks for next step.
-    sampled_token_ids_cpu: Optional[torch.Tensor] = None
-
-    # Spec decode metrics populated by workers.
-    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
-
-    # Optional last hidden states from the model.
-    hidden_states: Optional[torch.Tensor] = None
-
-    # Optional prefill hidden states from the model
-    # (used for models like EAGLE).
-    prefill_hidden_states: Optional[torch.Tensor] = None
-
-    # Time taken in the forward pass for this across all workers
-    model_forward_time: Optional[float] = None
-
-    def __getitem__(self, idx: int):
-        return self.outputs[idx]
-
-    def __setitem__(self, idx: int, value):
-        self.outputs[idx] = value
-
-    def __len__(self):
-        return len(self.outputs)
-
-    def __eq__(self, other: object):
-        return isinstance(other,
-                          self.__class__) and self.outputs == other.outputs
-
-    def __repr__(self) -> str:
-        """Show the shape of a tensor instead of its values to reduce noise.
-        """
-        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
-                                    else self.sampled_token_probs.shape)
-        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
-                                  self.sampled_token_ids.shape)
-        return (
-            f"SamplerOutput(outputs={self.outputs}, "
-            f"sampled_token_probs={sampled_token_probs_repr}, "
-            f"sampled_token_ids={sampled_token_ids_repr}, "
-            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
-
-
 class PoolerOutput(
         msgspec.Struct,
         omit_defaults=True,  # type: ignore[call-arg]

+ 4 - 3
aphrodite/engine/aphrodite_engine.py

@@ -24,9 +24,9 @@ from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
-                                       ExecuteModelRequest, SamplerOutput,
-                                       Sequence, SequenceGroup,
-                                       SequenceGroupMetadata, SequenceStatus)
+                                       ExecuteModelRequest, Sequence,
+                                       SequenceGroup, SequenceGroupMetadata,
+                                       SequenceStatus)
 from aphrodite.common.utils import Counter, Device
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics_types import StatLoggerBase, Stats
@@ -42,6 +42,7 @@ from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                               SingletonPromptInputs)
 from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.multimodal import MultiModalDataDict
 from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
                                             SchedulerOutputs)

+ 2 - 1
aphrodite/engine/async_aphrodite.py

@@ -14,7 +14,7 @@ from aphrodite.common.config import (DecodingConfig, EngineConfig, LoRAConfig,
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import print_warning_once
 from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
                                                DecoderPromptComponents,
@@ -29,6 +29,7 @@ from aphrodite.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                               SingletonPromptInputs)
 from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.processing.scheduler import SchedulerOutputs
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.transformers_utils.tokenizer import AnyTokenizer

+ 11 - 3
aphrodite/engine/output_processor/multi_step.py

@@ -11,6 +11,8 @@ from aphrodite.common.sequence import (Sequence, SequenceGroup,
 from aphrodite.common.utils import Counter
 from aphrodite.engine.output_processor.interfaces import (
     SequenceGroupOutputProcessor)
+from aphrodite.engine.output_processor.single_step import (
+    single_step_process_prompt_logprob)
 from aphrodite.engine.output_processor.stop_checker import StopChecker
 from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
@@ -46,9 +48,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
 
     def process_prompt_logprob(self, seq_group: SequenceGroup,
                                outputs: List[SequenceGroupOutput]) -> None:
-        # TODO: Prompt logprob currently not implemented in multi step
-        # workers.
-        self._log_prompt_logprob_unsupported_warning_once()
+        """Process prompt logprobs associated with each step of a multi-step-
+        scheduled computation.
+        Args:
+          seq_group: the outputs are associated with this :class:`SequenceGroup`
+          outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
+        """
+        for output in outputs:
+            # Concatenate single-step prompt logprob processing results.
+            single_step_process_prompt_logprob(self, seq_group, output)
 
     @staticmethod
     @functools.lru_cache()

+ 44 - 17
aphrodite/engine/output_processor/single_step.py

@@ -13,6 +13,42 @@ from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
 
 
+def single_step_process_prompt_logprob(
+        sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
+        output: SequenceGroupOutput) -> None:
+    """Process prompt logprobs associated with the :class:`SequenceGroupOutput`
+    for a given step.
+    Do nothing if the output has no prompt logprobs.
+    Account for the fact that transformers do not compute first-token logprobs.
+
+    Args:
+      sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
+      seq_group: the output is associated with this :class:`SequenceGroup`
+      output: the :class:`SequenceGroupOutput` for a single scheduler step
+    """
+    prompt_logprobs = output.prompt_logprobs
+
+    # If this is the first (or only) "chunk" of the prefill, we need
+    # to prepend None to the list of prompt logprobs. The reason for this
+    # is that for N prompt tokens, the Sampler will generate N-1 total
+    # prompt logprobs during prefill since the token at idx 0 will not
+    # have a logprob associated with it.
+    if prompt_logprobs is not None:
+        if not seq_group.prompt_logprobs:
+            prompt_logprobs = [None] + prompt_logprobs
+            seq_group.prompt_logprobs = []
+
+        assert hasattr(sg_output_proc, 'detokenizer')
+        if (seq_group.sampling_params.detokenize
+                and sg_output_proc.detokenizer):
+            sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
+                seq_group,
+                prompt_logprobs,
+                position_offset=len(seq_group.prompt_logprobs))
+
+        seq_group.prompt_logprobs.extend(prompt_logprobs)
+
+
 class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
     """SequenceGroupOutputProcessor which handles "output processing" logic,
     which happens after the model returns generated token ids and before
@@ -57,25 +93,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
 
     def process_prompt_logprob(self, seq_group: SequenceGroup,
                                outputs: List[SequenceGroupOutput]) -> None:
+        """Process prompt logprobs associated with one step of a single-step-
+        scheduled computation.
+        
+        Args:
+          seq_group: the output is associated with this :class:`SequenceGroup`
+          outputs: single-element list containing the
+            :class:`SequenceGroupOutput` for a single scheduler step
+        """
         assert len(outputs) == 1, ("Single step should only have 1 output.")
         output = outputs[0]
-        prompt_logprobs = output.prompt_logprobs
-
-        # If this is the first (or only) "chunk" of the prefill, we need
-        # to prepend None to the list of prompt logprobs. The reason for this
-        # is that for N prompt tokens, the Sampler will generate N-1 total
-        # prompt logprobs during prefill since the token at idx 0 will not
-        # have a logprob associated with it.
-        if prompt_logprobs is not None:
-            if not seq_group.prompt_logprobs:
-                prompt_logprobs = [None] + prompt_logprobs
-                seq_group.prompt_logprobs = []
-            if seq_group.sampling_params.detokenize and self.detokenizer:
-                self.detokenizer.decode_prompt_logprobs_inplace(
-                    seq_group,
-                    prompt_logprobs,
-                    position_offset=len(seq_group.prompt_logprobs))
-            seq_group.prompt_logprobs.extend(prompt_logprobs)
+        single_step_process_prompt_logprob(self, seq_group, output)
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                         outputs: SequenceGroupOutput,

+ 2 - 2
aphrodite/engine/output_processor/util.py

@@ -2,8 +2,8 @@ from typing import List
 from typing import Sequence as GenericSequence
 from typing import Union
 
-from aphrodite.common.sequence import (PoolerOutput, SamplerOutput,
-                                       SequenceGroupOutput)
+from aphrodite.common.sequence import PoolerOutput, SequenceGroupOutput
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 
 def create_output_by_sequence_group(

+ 1 - 1
aphrodite/engine/protocol.py

@@ -6,9 +6,9 @@ from aphrodite.common.config import DecodingConfig, ModelConfig
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.inputs.data import PromptInputs
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.processing.scheduler import SchedulerOutputs
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 

+ 2 - 1
aphrodite/executor/cpu_executor.py

@@ -7,7 +7,7 @@ from loguru import logger
 
 import aphrodite.common.envs as envs
 from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (GiB_bytes, get_aphrodite_instance_id,
                                     get_distributed_init_method, get_open_port,
                                     make_async)
@@ -16,6 +16,7 @@ from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                        ResultHandler,
                                                        WorkerMonitor)
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.task_handler.worker_base import WorkerWrapperBase
 

+ 2 - 1
aphrodite/executor/distributed_gpu_executor.py

@@ -4,10 +4,11 @@ from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
 
 from loguru import logger
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 
 class DistributedGPUExecutor(GPUExecutor):

+ 2 - 1
aphrodite/executor/executor_base.py

@@ -5,8 +5,9 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 
 

+ 2 - 2
aphrodite/executor/gpu_executor.py

@@ -2,12 +2,12 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 from loguru import logger
 
-from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
-                                       SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, PoolerOutput
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.task_handler.worker_base import WorkerBase, WorkerWrapperBase
 

+ 2 - 1
aphrodite/executor/multiproc_gpu_executor.py

@@ -9,7 +9,7 @@ from typing import Any, List, Optional
 import torch
 from loguru import logger
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (_run_task_with_lock,
                                     cuda_device_count_stateless,
                                     get_aphrodite_instance_id,
@@ -21,6 +21,7 @@ from aphrodite.executor.gpu_executor import create_worker
 from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                        ResultHandler,
                                                        WorkerMonitor)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.triton_utils import maybe_set_triton_cache_manager
 
 

+ 2 - 1
aphrodite/executor/neuron_executor.py

@@ -1,10 +1,11 @@
 from typing import List, Set, Tuple
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 
 class NeuronExecutor(ExecutorBase):

+ 2 - 1
aphrodite/executor/openvino_executor.py

@@ -7,11 +7,12 @@ from loguru import logger
 
 import aphrodite.common.envs as envs
 from aphrodite.common.config import CacheConfig, ModelConfig
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (GiB_bytes, get_distributed_init_method,
                                     get_ip, get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 APHRODITE_OPENVINO_KVCACHE_SPACE = envs.APHRODITE_OPENVINO_KVCACHE_SPACE
 APHRODITE_OPENVINO_CPU_KV_CACHE_PRECISION = (

+ 2 - 1
aphrodite/executor/ray_gpu_executor.py

@@ -8,7 +8,7 @@ import msgspec
 from loguru import logger
 
 import aphrodite.common.envs as envs
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (_run_task_with_lock,
                                     get_aphrodite_instance_id,
                                     get_distributed_init_method, get_ip,
@@ -17,6 +17,7 @@ from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from aphrodite.executor.msgspec_utils import encode_hook
 from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

+ 2 - 1
aphrodite/executor/ray_tpu_executor.py

@@ -8,13 +8,14 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
 from loguru import logger
 
 import aphrodite.common.envs as envs
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_aphrodite_instance_id,
                                     get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
 from aphrodite.executor.tpu_executor import TPUExecutor
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

+ 2 - 1
aphrodite/executor/tpu_executor.py

@@ -3,11 +3,12 @@ from typing import Any, Dict, List, Optional, Set, Tuple
 import torch
 from loguru import logger
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 
 class TPUExecutor(ExecutorBase):

+ 2 - 2
aphrodite/executor/xpu_executor.py

@@ -7,11 +7,11 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
-from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
-                                       SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, PoolerOutput
 from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.task_handler.worker_base import WorkerBase
 
 

+ 247 - 44
aphrodite/modeling/layers/sampler.py

@@ -1,10 +1,12 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
 import warnings
+from dataclasses import dataclass
 from enum import IntEnum
 from math import inf
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
+import msgspec
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -14,7 +16,8 @@ import aphrodite.common.envs as envs
 from aphrodite.common.sampling_params import SamplingType
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        PromptLogprobs, SampleLogprobs,
-                                       SamplerOutput, SequenceOutput)
+                                       SequenceOutput)
+from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 from aphrodite.triton_utils import HAS_TRITON
 
 if HAS_TRITON:
@@ -27,6 +30,115 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
 
+# Types of temporary data structures used for
+# computing sample_result
+SampleMetadataType = Dict[SamplingType, Tuple[List[int],
+                                              List[SequenceGroupToSample]]]
+MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
+SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
+
+
+# Encapsulates temporary data structures for computing
+# sample_result.
+#
+# * For multi-step scheduling: must be returned
+#   by `Sampler.forward()` and used later to compute the pythonized
+#   sample_result
+#
+# * For single-step scheduling: consumed immediately
+#   inside `Sampler.forward()` to compute pythonized sample_result.
+@dataclass
+class SampleResultArgsType:
+    sample_metadata: SampleMetadataType
+    multinomial_samples: MultinomialSamplesType
+    sample_results_dict: SampleResultsDictType
+    sampling_metadata: SamplingMetadata
+    greedy_samples: Optional[torch.Tensor]
+    beam_search_logprobs: Optional[torch.Tensor]
+
+
+# Union of non-deferred (single-step scheduling)
+# vs deferred (multi-step scheduling)
+# sample result types
+MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
+
+# Abbreviation of the _sample() return type
+SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
+
+
+class SamplerOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+    """For each sequence group, we generate a list of SequenceOutput object,
+    each of which contains one possible candidate for the next token.
+    This data structure implements methods, so it can be used like a list, but
+    also has optional fields for device tensors.
+    """
+
+    outputs: List[CompletionSequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional[torch.Tensor] = None
+
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None
+
+    # Holds either (1) the pythonized sampler result (single-step scheduling)
+    # or (2) what will be arguments for later deferred pythonization of the
+    # sampler result (multi-step scheduling)
+    deferred_sample_results_args: Optional[SampleResultArgsType] = None
+
+    # On-device tensor containing the sampled token ids.
+    sampled_token_ids: Optional[torch.Tensor] = None
+    # CPU tensor containing the sampled token ids. Used during multi-step to
+    # return the sampled token ids from last rank to AsyncLLMEngine to be
+    # 'broadcasted' to all other PP ranks for next step.
+    sampled_token_ids_cpu: Optional[torch.Tensor] = None
+
+    # Spec decode metrics populated by workers.
+    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
+
+    # Optional last hidden states from the model.
+    hidden_states: Optional[torch.Tensor] = None
+
+    # Optional prefill hidden states from the model
+    # (used for models like EAGLE).
+    prefill_hidden_states: Optional[torch.Tensor] = None
+
+    # Time taken in the forward pass for this across all workers
+    model_forward_time: Optional[float] = None
+
+    # Time taken in the model execute function. This will include model forward,
+    # block/sync across workers, cpu-gpu sync time and sampling time.
+    model_execute_time: Optional[float] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+    def __repr__(self) -> str:
+        """Show the shape of a tensor instead of its values to reduce noise.
+        """
+        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
+                                    else self.sampled_token_probs.shape)
+        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
+                                  self.sampled_token_ids.shape)
+        return (
+            f"SamplerOutput(outputs={self.outputs}, "
+            f"sampled_token_probs={sampled_token_probs_repr}, "
+            f"sampled_token_ids={sampled_token_ids_repr}, "
+            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
 # There isn't a "safe" temperature range for fp16 logits.
 # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
 # that this temperature well-uses the fp16 space after the logits are offset.
@@ -135,6 +247,18 @@ class Sampler(nn.Module):
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
         """
+        Single-step scheduling:
+        * Perform GPU-side sampling computation & compute
+          GPU-side logprobs tensor
+        * Pythonize sampling result & logprobs tensor
+        Multi-step scheduling:
+        * Perform GPU-side sampling computation & compute
+          GPU-side logprobs tensor
+        * Defer Pythonization of sampling result & logprobs
+          tensor
+        * Encapsulate arguments required for deferred Pythonization
+          in the :class:`SamplerOutput` structure
+
         Args:
             logits: (num_tokens, vocab_size).
             sampling_metadata: Metadata for sampling.
@@ -425,7 +549,7 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
-        sample_results, maybe_sampled_tokens_tensor = _sample(
+        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
             probs,
             logprobs,
             sampling_metadata,
@@ -435,20 +559,28 @@ class Sampler(nn.Module):
         )
 
         if self.include_gpu_probs_tensor:
+            # Since we will defer sampler result Pythonization,
+            # preserve GPU-side tensors in support of later
+            # deferred pythonization of logprobs
             assert maybe_sampled_tokens_tensor is not None
             on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
         else:
+            # Since Pythonization has already happened, don't preserve
+            # GPU-side tensors.
             on_device_tensors = None
 
         # Get the logprobs query results.
         prompt_logprobs = None
         sample_logprobs = None
         if not sampling_metadata.skip_sampler_cpu_output:
-            prompt_logprobs, sample_logprobs = _get_logprobs(
-                logprobs, sampling_metadata, sample_results)
+            # Pythonize logprobs now (GPU -> CPU); do not defer.
+            assert not isinstance(maybe_deferred_sample_results,
+                                  SampleResultArgsType)
+            prompt_logprobs, sample_logprobs = get_logprobs(
+                logprobs, sampling_metadata, maybe_deferred_sample_results)
 
         return _build_sampler_output(
-            sample_results,
+            maybe_deferred_sample_results,
             sampling_metadata,
             prompt_logprobs,
             sample_logprobs,
@@ -1205,6 +1337,57 @@ def _top_k_top_p_multinomial_with_kernels(
     return batch_next_token_ids.view(-1, num_samples)
 
 
+def get_pythonized_sample_results(
+        sample_result_args: SampleResultArgsType) -> SampleResultType:
+    """This function consumes GPU-side sampler results and computes
+    Pythonized CPU-side sampler results (GPU -> CPU sync.)
+    Single-step scheduling: this function is invoked at sampling-time
+    for immediate Pythonization.
+    Multi-step scheduling: Pythonization is deferred until after multiple
+    GPU-side steps have been completed.
+
+    Args:
+      sample_result_args: GPU-side inputs to the Pythonization process
+    Returns:
+      Pythonized sampler results
+    """
+
+    (
+        sample_metadata,
+        sampling_metadata,
+        greedy_samples,
+        multinomial_samples,
+        beam_search_logprobs,
+        sample_results_dict,
+    ) = (
+        sample_result_args.sample_metadata,
+        sample_result_args.sampling_metadata,
+        sample_result_args.greedy_samples,
+        sample_result_args.multinomial_samples,
+        sample_result_args.beam_search_logprobs,
+        sample_result_args.sample_results_dict,
+    )
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, greedy_samples)
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
+            sample_results = _random_sample(seq_groups,
+                                            multinomial_samples[sampling_type])
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups,
+                                                 beam_search_logprobs)
+        sample_results_dict.update(zip(seq_group_id, sample_results))
+
+    return [
+        sample_results_dict.get(i, ([], []))
+        for i in range(len(sampling_metadata.seq_groups))
+    ]
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -1212,7 +1395,18 @@ def _sample_with_torch(
     sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+) -> SampleReturnType:
+    """Torch-oriented _sample() implementation.
+
+    Single-step scheduling: 
+    * Perform GPU-side sampling computation
+    * Immediately Pythonize sampling result
+
+    Multi-step scheduling:
+    * Perform GPU-side sampling computation
+    * Defer Pythonization & preserve GPU-side
+      tensors required for Pythonization
+    """
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
@@ -1220,9 +1414,11 @@ def _sample_with_torch(
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
 
-    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
-    sample_metadata = {}
-    multinomial_samples = {}
+    sample_results_dict: SampleResultsDictType = {}
+    sample_metadata: SampleMetadataType = {}
+    multinomial_samples: MultinomialSamplesType = {}
+    greedy_samples: Optional[torch.Tensor] = None
+    beam_search_logprobs: Optional[torch.Tensor] = None
     # Create output tensor for sampled token ids.
     if include_gpu_probs_tensor:
         sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
@@ -1293,32 +1489,29 @@ def _sample_with_torch(
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
-    # GPU<->CPU sync happens in the loop below.
-    # This also converts the sample output to Python objects.
+    # Encapsulate arguments for computing Pythonized sampler
+    # results, whether deferred or otherwise.
+    maybe_deferred_args = SampleResultArgsType(
+        sampling_metadata=sampling_metadata,
+        sample_metadata=sample_metadata,
+        multinomial_samples=multinomial_samples,
+        greedy_samples=greedy_samples,
+        beam_search_logprobs=beam_search_logprobs,
+        sample_results_dict=sample_results_dict)
+
     if not sampling_metadata.skip_sampler_cpu_output:
-        for sampling_type in SamplingType:
-            if sampling_type not in sample_metadata:
-                continue
-            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
-            if sampling_type == SamplingType.GREEDY:
-                sample_results = _greedy_sample(seq_groups, greedy_samples)
-            elif sampling_type in (SamplingType.RANDOM,
-                                   SamplingType.RANDOM_SEED):
-                sample_results = _random_sample(
-                    seq_groups, multinomial_samples[sampling_type])
-            elif sampling_type == SamplingType.BEAM:
-                sample_results = _beam_search_sample(seq_groups,
-                                                     beam_search_logprobs)
-            sample_results_dict.update(zip(seq_group_id, sample_results))
-
-        sample_results = [
-            sample_results_dict.get(i, ([], []))
-            for i in range(len(sampling_metadata.seq_groups))
-        ]
+        # GPU<->CPU sync happens here.
+        # This also converts the sampler output to a Python object.
+        # Return Pythonized sampler result & sampled token ids
+        return get_pythonized_sample_results(
+            maybe_deferred_args), sampled_token_ids_tensor
     else:
-        sample_results = []
-
-    return sample_results, sampled_token_ids_tensor
+        # Defer sampler result Pythonization; return deferred
+        # Pythonization args & sampled token ids
+        return (
+            maybe_deferred_args,
+            sampled_token_ids_tensor,
+        )
 
 
 def _sample_with_triton_kernel(
@@ -1396,10 +1589,13 @@ def _sample_with_triton_kernel(
 
 
 def _sample(
-    probs: torch.Tensor, logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
+) -> SampleReturnType:
     """
     Args:
         probs: (num_query_tokens_in_batch, num_vocab)
@@ -1441,7 +1637,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     return (x > vals[:, None]).long().sum(1).add_(1)
 
 
-def _get_logprobs(
+def get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sample_results: List[Tuple[List[int], List[int]]],
@@ -1755,7 +1951,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 
 
 def _build_sampler_output(
-    sample_results: SampleResultType,
+    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
     sample_logprobs: Optional[List[SampleLogprobs]],
@@ -1771,14 +1967,21 @@ def _build_sampler_output(
             speculative decoding rejection sampling.
     """
     sampler_output: List[CompletionSequenceGroupOutput] = []
-    if not skip_sampler_cpu_output:
+
+    if skip_sampler_cpu_output:
+        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
+        deferred_sample_results_args = maybe_deferred_sample_results
+    else:
         assert prompt_logprobs is not None
         assert sample_logprobs is not None
+        assert not isinstance(maybe_deferred_sample_results,
+                              SampleResultArgsType)
+        deferred_sample_results_args = None
 
         for (seq_group, sample_result, group_prompt_logprobs,
              group_sample_logprobs) in zip(sampling_metadata.seq_groups,
-                                           sample_results, prompt_logprobs,
-                                           sample_logprobs):
+                                           maybe_deferred_sample_results,
+                                           prompt_logprobs, sample_logprobs):
             seq_ids = seq_group.seq_ids
             next_token_ids, parent_ids = sample_result
             seq_outputs: List[SequenceOutput] = []
@@ -1802,7 +2005,7 @@ def _build_sampler_output(
         sampled_token_probs=sampled_token_probs,
         sampled_token_ids=sampled_token_ids,
         logprobs=logprobs_tensor,
-    )
+        deferred_sample_results_args=deferred_sample_results_args)
 
 
 def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:

+ 1 - 2
aphrodite/modeling/model_loader/neuron.py

@@ -10,9 +10,8 @@ from transformers import PretrainedConfig
 
 from aphrodite.common.config import (ModelConfig, ParallelConfig,
                                      SchedulerConfig)
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 
 TORCH_DTYPE_TO_NEURON_AMP = {

+ 1 - 2
aphrodite/modeling/model_loader/openvino.py

@@ -13,10 +13,9 @@ from torch import nn
 import aphrodite.common.envs as envs
 from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
 from aphrodite.common.config import DeviceConfig, ModelConfig
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import (LogitsProcessor,
                                                         _prune_hidden_states)
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 
 APHRODITE_OPENVINO_ENABLE_QUANTIZED_WEIGHTS = (

+ 2 - 2
aphrodite/modeling/models/arctic.py

@@ -7,7 +7,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -20,7 +20,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/baichuan.py

@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/bart.py

@@ -25,14 +25,14 @@ from transformers import BartConfig
 
 from aphrodite.attention import Attention, AttentionMetadata, AttentionType
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 3
aphrodite/modeling/models/blip2.py

@@ -9,13 +9,12 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.opt import OPTModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata

+ 2 - 2
aphrodite/modeling/models/bloom.py

@@ -25,7 +25,7 @@ from transformers import BloomConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 3
aphrodite/modeling/models/chameleon.py

@@ -11,8 +11,7 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.common.utils import print_warning_once
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.distributed import get_tensor_model_parallel_world_size
@@ -24,7 +23,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/chatglm.py

@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -19,7 +19,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/commandr.py

@@ -29,7 +29,7 @@ from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/dbrx.py

@@ -6,7 +6,7 @@ import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -16,7 +16,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/deepseek.py

@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/deepseek_v2.py

@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/modeling/models/eagle.py

@@ -4,8 +4,9 @@ import torch
 import torch.nn as nn
 
 from aphrodite.attention.backends.abstract import AttentionMetadata
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/exaone.py

@@ -30,7 +30,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/falcon.py

@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -38,7 +38,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/fuyu.py

@@ -27,11 +27,11 @@ from transformers import FuyuConfig, FuyuImageProcessor
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ColumnParallelLinear
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
 from aphrodite.modeling.sampling_metadata import SamplingMetadata

+ 2 - 2
aphrodite/modeling/models/gemma.py

@@ -24,7 +24,7 @@ from transformers import GemmaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gemma2.py

@@ -24,7 +24,7 @@ from transformers import Gemma2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt2.py

@@ -25,14 +25,14 @@ from transformers import GPT2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt_bigcode.py

@@ -26,14 +26,14 @@ from transformers import GPTBigCodeConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt_j.py

@@ -24,7 +24,7 @@ from transformers import GPTJConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -32,7 +32,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/gpt_neox.py

@@ -24,7 +24,7 @@ from transformers import GPTNeoXConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -32,7 +32,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/internlm2.py

@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -16,7 +16,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/modeling/models/internvl.py

@@ -16,8 +16,9 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.intern_vit import InternVisionModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata

+ 2 - 2
aphrodite/modeling/models/jais.py

@@ -27,14 +27,14 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/jamba.py

@@ -11,7 +11,7 @@ from transformers import JambaConfig
 from aphrodite.attention.backends.abstract import AttentionMetadata
 from aphrodite.attention.layer import Attention
 from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 # yapf: disable
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -28,7 +28,7 @@ from aphrodite.modeling.layers.mamba import (causal_conv1d_fn,
                                              causal_conv1d_update,
                                              selective_scan_fn,
                                              selective_state_update)
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/llama.py

@@ -29,7 +29,7 @@ from transformers import LlamaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 1
aphrodite/modeling/models/llava.py

@@ -8,9 +8,10 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY

+ 2 - 1
aphrodite/modeling/models/llava_next.py

@@ -12,9 +12,10 @@ from typing_extensions import NotRequired
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_list_of
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY

+ 2 - 2
aphrodite/modeling/models/mamba.py

@@ -10,7 +10,7 @@ from transformers import MambaConfig
 
 from aphrodite.attention.backends.abstract import AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -23,7 +23,7 @@ from aphrodite.modeling.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from aphrodite.modeling.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 1 - 1
aphrodite/modeling/models/medusa.py

@@ -3,8 +3,8 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/minicpm.py

@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -44,7 +44,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 3
aphrodite/modeling/models/minicpmv.py

@@ -39,13 +39,12 @@ from transformers.configuration_utils import PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData)
+from aphrodite.common.sequence import IntermediateTensors, SequenceData
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ReplicatedLinear
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/mixtral.py

@@ -29,7 +29,7 @@ from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.fused_moe import FusedMoE
@@ -39,7 +39,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/mixtral_quant.py

@@ -31,7 +31,7 @@ from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -41,7 +41,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 1 - 2
aphrodite/modeling/models/mlp_speculator.py

@@ -4,9 +4,8 @@ from typing import Iterable, List, Tuple
 import torch
 import torch.nn as nn
 
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/mpt.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -16,7 +16,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/nemotron.py

@@ -30,7 +30,7 @@ from transformers import NemotronConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -39,7 +39,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/olmo.py

@@ -29,7 +29,7 @@ from transformers import OlmoConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/olmoe.py

@@ -18,7 +18,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.fused_moe import FusedMoE
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -27,7 +27,7 @@ from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/opt.py

@@ -25,7 +25,7 @@ from transformers import OPTConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -33,7 +33,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/orion.py

@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -20,7 +20,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/paligemma.py

@@ -8,10 +8,10 @@ from transformers import PaliGemmaConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.gemma import GemmaModel
 from aphrodite.modeling.models.gemma2 import Gemma2Model

+ 2 - 2
aphrodite/modeling/models/persimmon.py

@@ -30,14 +30,14 @@ from transformers.activations import ReLUSquaredActivation
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/phi.py

@@ -43,7 +43,7 @@ from transformers import PhiConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -51,7 +51,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/phi3_small.py

@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -15,7 +15,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/phi3v.py

@@ -29,11 +29,11 @@ from transformers import CLIPVisionConfig, PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_list_of
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.clip import CLIPVisionModel

+ 2 - 2
aphrodite/modeling/models/qwen.py

@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -21,7 +21,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/qwen2.py

@@ -30,7 +30,7 @@ from transformers import Qwen2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -42,7 +42,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/qwen2_moe.py

@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size,
@@ -45,7 +45,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/solar.py

@@ -29,7 +29,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -41,7 +41,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import (

+ 2 - 2
aphrodite/modeling/models/stablelm.py

@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -35,7 +35,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 2
aphrodite/modeling/models/starcoder2.py

@@ -26,7 +26,7 @@ from transformers import Starcoder2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -34,7 +34,7 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/modeling/models/ultravox.py

@@ -20,12 +20,13 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       SamplerOutput, SequenceData)
+                                       SequenceData)
 from aphrodite.inputs import INPUT_REGISTRY
 from aphrodite.inputs.data import LLMInputs
 from aphrodite.inputs.registry import InputContext
 from aphrodite.modeling.layers.activation import SiluAndMul, get_act_fn
 from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.interfaces import SupportsMultiModal
 from aphrodite.modeling.models.utils import (filter_weights,

+ 2 - 2
aphrodite/modeling/models/xverse.py

@@ -28,7 +28,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -37,7 +37,7 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader

+ 2 - 1
aphrodite/quantization/gguf.py

@@ -125,7 +125,8 @@ class GGUFLinearMethod(LinearMethodBase):
                     offset:offset +
                     shard_size[id][0], :shard_size[id][1]].contiguous()
                 qweight_type = layer.qweight_type.shard_weight_type[id]
-                result.append(_fuse_mul_mat(x, shard_weight, qweight_type).contiguous())
+                result.append(_fuse_mul_mat(x, shard_weight,
+                                            qweight_type).contiguous())
                 offset += shard_size[id][0]
             out = torch.cat(result, dim=-1)
         else:

+ 3 - 3
aphrodite/spec_decode/batch_expansion.py

@@ -6,9 +6,9 @@ import torch
 
 from aphrodite import SamplingParams
 from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
-                                       ExecuteModelRequest, SamplerOutput,
-                                       SequenceData, SequenceGroupMetadata,
-                                       get_all_seq_ids)
+                                       ExecuteModelRequest, SequenceData,
+                                       SequenceGroupMetadata, get_all_seq_ids)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeScorer,
                                               SpeculativeScores)

+ 2 - 2
aphrodite/spec_decode/draft_model_runner.py

@@ -15,8 +15,8 @@ except ModuleNotFoundError:
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
-from aphrodite.common.sequence import (ExecuteModelRequest,
-                                       IntermediateTensors, SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.multimodal import MultiModalInputs
 from aphrodite.task_handler.model_runner import (
     ModelInputForGPUWithSamplingMetadata, ModelRunner)

+ 2 - 1
aphrodite/spec_decode/medusa_worker.py

@@ -3,8 +3,9 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+from aphrodite.common.sequence import (ExecuteModelRequest,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase

+ 2 - 1
aphrodite/spec_decode/mlp_speculator_worker.py

@@ -2,8 +2,9 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+from aphrodite.common.sequence import (ExecuteModelRequest,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase

+ 2 - 2
aphrodite/spec_decode/multi_step_worker.py

@@ -5,8 +5,8 @@ from typing import Dict, List, Set, Tuple
 import torch
 
 from aphrodite.common.sequence import (ExecuteModelRequest, HiddenStates,
-                                       SamplerOutput, SequenceData,
-                                       SequenceGroupMetadata)
+                                       SequenceData, SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
 from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeProposer)

+ 2 - 1
aphrodite/spec_decode/ngram_worker.py

@@ -3,7 +3,8 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from aphrodite.spec_decode.top1_proposer import Top1Proposer

+ 2 - 1
aphrodite/spec_decode/proposer_worker_base.py

@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposer
 from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
 

+ 2 - 1
aphrodite/spec_decode/smaller_tp_proposer_worker.py

@@ -3,10 +3,11 @@ from typing import List, Optional, Set, Tuple
 import torch
 from loguru import logger
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed.parallel_state import (get_tp_group,
                                                   init_model_parallel_group,
                                                   patch_tensor_parallel_group)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
 from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase

+ 2 - 2
aphrodite/spec_decode/spec_decode_worker.py

@@ -8,11 +8,11 @@ from loguru import logger
 from aphrodite.common.config import ParallelConfig, SpeculativeConfig
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        ExecuteModelRequest, HiddenStates,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       get_all_seq_ids,
+                                       SequenceGroupMetadata, get_all_seq_ids,
                                        get_all_seq_ids_and_request_ids)
 from aphrodite.distributed.communication_op import broadcast_tensor_dict
 from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from aphrodite.modeling.layers.typical_acceptance_sampler import (

+ 2 - 1
aphrodite/spec_decode/top1_proposer.py

@@ -2,8 +2,9 @@ from typing import List, Optional, Set, Tuple
 
 import torch
 
-from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+from aphrodite.common.sequence import (ExecuteModelRequest,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeProposer)
 from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase

+ 2 - 2
aphrodite/spec_decode/util.py

@@ -5,8 +5,8 @@ from typing import Dict, List, Optional, Sequence, Tuple
 import torch
 
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       SequenceOutput)
+                                       SequenceGroupMetadata, SequenceOutput)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 
 SeqId = int
 

+ 2 - 1
aphrodite/task_handler/cpu_model_runner.py

@@ -8,9 +8,10 @@ from aphrodite.attention import AttentionMetadata, get_attn_backend
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import make_tensor_with_pad
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,

+ 2 - 1
aphrodite/task_handler/enc_dec_model_runner.py

@@ -17,10 +17,11 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
-                                       SamplerOutput, SequenceGroupMetadata)
+                                       SequenceGroupMetadata)
 from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
                                     make_tensor_with_pad)
 from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from aphrodite.task_handler.model_runner import (

+ 2 - 1
aphrodite/task_handler/model_runner.py

@@ -23,7 +23,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
                                     async_tensor_h2d, flatten_2d_lists, is_hip,
@@ -36,6 +36,7 @@ from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
 from aphrodite.lora.layers import LoRAMapping
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.modeling.models.interfaces import (supports_lora,

+ 2 - 1
aphrodite/task_handler/model_runner_base.py

@@ -5,8 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
 
 import torch
 
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.platforms import current_platform
 
 if TYPE_CHECKING:

+ 147 - 29
aphrodite/task_handler/multi_step_model_runner.py

@@ -1,7 +1,8 @@
 import dataclasses
 import functools
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Union)
 
 try:
     from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
@@ -16,9 +17,12 @@ import torch
 from aphrodite import _custom_ops as ops
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        IntermediateTensors, Logprob,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       SequenceOutput)
+                                       SequenceGroupMetadata, SequenceOutput)
 from aphrodite.distributed import get_pp_group
+from aphrodite.modeling.layers.sampler import (PromptLogprobs, SampleLogprobs,
+                                               SamplerOutput, SamplingMetadata,
+                                               get_logprobs,
+                                               get_pythonized_sample_results)
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.task_handler.model_runner import (
     GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata)
@@ -50,6 +54,8 @@ class ModelOutput:
     sampler_output_ready_event: torch.cuda.Event
     sampled_token_ids: Optional[torch.Tensor] = None
     pythonized: bool = False
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None
 
     def pythonize(
         self,
@@ -85,7 +91,9 @@ class ModelOutput:
     ) -> bool:
         """
         If blocking is set, will block until the forward pass for the output is
-        ready and pythonize the output.
+        ready and pythonize the output. Upon completing Pythonization, erases
+        self.logprobs (note that a non-blocking call that is performed when
+        the sampler output is not yet ready, will not erase self.logprobs.)
         """
         assert self.sampled_token_ids is not None
         if not blocking and not self.sampler_output_ready_event.query():
@@ -97,8 +105,15 @@ class ModelOutput:
                 input_metadata,
                 self.sampler_output,
                 pinned_sampled_token_buffer,
-                self.sampled_token_ids,
-            )
+                self.sampled_token_ids, self.logprobs)
+
+        # Erase the logprobs GPU-side tensor.
+        # Note that although _pythonize_sampler_output() runs in its
+        # own CUDA stream, nonetheless _pythonize_sampler_output()
+        # cannot return until Pythonization is complete; therefore
+        # we know that by the time the CPU reaches this point,
+        # `self.logprobs` is no longer needed.
+        self.logprobs = None
         return True
 
 
@@ -355,11 +370,12 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
                 ModelOutput(
                     output[0],
                     output_ready_event,
-                    output[0].sampled_token_ids,
-                    False,
-                )
-            )
-            # make sure we dont try to serialize any GPU tensors
+                    output[0].sampled_token_ids, False,
+                            output[0].logprobs))
+
+            # These GPU tensors are not required by multi-step;
+            # erase them to ensure they are not pythonized or
+            # transferred to CPU
             output[0].sampled_token_ids = None
             output[0].sampled_token_probs = None
             output[0].logprobs = None
@@ -464,14 +480,72 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         return self._base_model_runner.vocab_size
 
 
+DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
+                                   Optional[List[SampleLogprobs]]]
+
+
+def deferred_pythonize_logprobs(
+    output: SamplerOutput,
+    sampling_metadata: SamplingMetadata,
+    logprobs_tensor: Optional[torch.Tensor],
+) -> DeferredLogprobsReturnType:
+    """Perform deferred logprob Pythonization.
+    1. Pythonize GPU-side sampler result tensors into CPU-side sampler result.
+    2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists,
+       utilizing the Pythonized sampler result computed in step 1.
+
+    These deferred computations are not required for single-step scheduling
+    or the `profile_run()` phase of multi-step scheduling.
+    Args:
+        output: sampler output (under deferred Pythonization)
+        sampling_metadata: metadata for the sequence groups being sampled
+        logprobs_tensor: GPU-side tensor of logprobs computed during sampling
+    Returns:
+        prompt_logprobs (CPU), sample_logprobs (CPU)
+    """
+
+    # - Deferred pythonization of sample result
+    sampler_result = get_pythonized_sample_results(
+        output.deferred_sample_results_args)
+
+    # - Erase the GPU-side deferred sample_result
+    #   computation args to ensure it is never
+    #   pythonized or transferred to CPU
+    output.deferred_sample_results_args = None
+
+    # - Deferred pythonization of logprobs
+    (
+        prompt_logprobs,
+        sample_logprobs,
+    ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
+    assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
+    assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
+
+    return prompt_logprobs, sample_logprobs
+
+
 def _pythonize_sampler_output(
     model_input: StatefulModelInput,
     output: SamplerOutput,
     pinned_sampled_token_buffer: torch.Tensor,
     sampled_token_ids: torch.Tensor,
+    logprobs_tensor: Optional[torch.Tensor],
 ) -> None:
-    """This function is only called when the output tensors are ready.
-    See ModelOutput
+    """This function is only called when the output tensors are ready.
+    See :class:`ModelOutput`.
+
+    Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
+    adding a Pythonized output data structure
+    (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
+    Args:
+      model_input
+      output: sampler output
+      pinned_sampled_token_buffer: CPU-side pinned memory
+                                         (receives copy of
+                                         GPU-side token buffer.)
+      sampled_token_ids: GPU-side token buffer
+      logprobs_tensor: GPU-side tensor containing 
+                       logprobs computed during sampling
     """
     assert model_input.frozen_model_input is not None
     frozen_model_input = model_input.frozen_model_input
@@ -484,26 +558,70 @@ def _pythonize_sampler_output(
     # this will not block as the tensors are already on CPU
     samples_list = pinned_buffer.tolist()
     sampling_metadata = frozen_model_input.sampling_metadata
-    for seq_group, sample_result in zip(
-        sampling_metadata.seq_groups, samples_list
-    ):
+    skip_sampler_cpu_output = (
+        frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
+
+    # We are guaranteed output tensors are ready, so it is safe to
+    # pythonize the sampler output & obtain CPU-side logprobs.
+    #
+    # However this computation may be skipped entirely
+    # if no pythonization was deferred.
+    seq_groups = sampling_metadata.seq_groups
+    logprobs_are_requested = any([
+        sg.sampling_params.logprobs is not None
+        or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
+    ])
+    do_pythonize_logprobs = (skip_sampler_cpu_output
+                             and logprobs_are_requested)
+    (
+        prompt_logprobs,
+        sample_logprobs,
+    ) = (deferred_pythonize_logprobs(output, sampling_metadata,
+                                     logprobs_tensor)
+         if do_pythonize_logprobs else (None, None))
+
+    for sgdx, (seq_group,
+               sample_result) in enumerate(zip(seq_groups, samples_list)):
+
+        if do_pythonize_logprobs:
+            assert prompt_logprobs is not None
+            assert sample_logprobs is not None
+
+            (
+                group_prompt_logprobs,
+                group_sample_logprobs,
+            ) = (  # Utilize deferred pythonization results
+                prompt_logprobs[sgdx],
+                sample_logprobs[sgdx],
+            )
+        elif logprobs_are_requested:
+            (
+                group_prompt_logprobs,
+                group_sample_logprobs,
+            ) = (
+                # profile_run: use already-computed logprobs
+                output.outputs[sgdx].prompt_logprobs,
+                [sample.logprobs for sample in output.outputs[sgdx].samples])
         seq_ids = seq_group.seq_ids
         next_token_ids = sample_result
         parent_ids = [0]
         seq_outputs: List[SequenceOutput] = []
         if seq_group.sampling_params.logits_processors:
-            assert (
-                len(seq_group.sampling_params.logits_processors) == 0
-            ), "Logits Processors are not supported in multi-step decoding"
-        for parent_id, next_token_id in zip(parent_ids, next_token_ids):
-            # TODO(will): support logprobs
-            # Hard coded logprob
+            assert len(seq_group.sampling_params.logits_processors) == 0, (
+                "Logits Processors are not supported in multi-step decoding")
+        for tdx, (parent_id,
+                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
             seq_outputs.append(
-                SequenceOutput(
-                    seq_ids[parent_id],
-                    next_token_id,
-                    {next_token_id: Logprob(logprob=-1)},
-                )
-            )
-        output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
+                SequenceOutput(seq_ids[parent_id], next_token_id,
+                               (group_sample_logprobs[tdx]
+                                if logprobs_are_requested else {
+                                    next_token_id:
+                                    Logprob(logprob=float('inf'),
+                                            rank=None,
+                                            decoded_token=None)
+                                })))
+        output.outputs.append(
+            CompletionSequenceGroupOutput(
+                seq_outputs,
+                (group_prompt_logprobs if logprobs_are_requested else None)))
     assert len(output.outputs) > 0

+ 2 - 1
aphrodite/task_handler/multi_step_worker.py

@@ -4,8 +4,9 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.task_handler.model_runner_base import BroadcastableModelInput
 from aphrodite.task_handler.multi_step_model_runner import (
     MultiStepModelRunner, StatefulModelInput)

+ 2 - 1
aphrodite/task_handler/neuron_model_runner.py

@@ -7,10 +7,11 @@ from torch import nn
 
 from aphrodite.common.config import (DeviceConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig)
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import (is_pin_memory_available,
                                     make_tensor_with_pad)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.neuron import get_neuron_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,

+ 2 - 1
aphrodite/task_handler/openvino_model_runner.py

@@ -9,7 +9,8 @@ from aphrodite.attention.backends.openvino import OpenVINOAttentionMetadata
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import SequenceGroupMetadata
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.openvino import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,

+ 2 - 1
aphrodite/task_handler/openvino_worker.py

@@ -9,11 +9,12 @@ from aphrodite.attention import get_attn_backend
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import (broadcast_tensor_dict,
                                    ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.task_handler.openvino_model_runner import OpenVINOModelRunner
 from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
 

+ 2 - 2
aphrodite/task_handler/tpu_model_runner.py

@@ -16,10 +16,10 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      SchedulerConfig)
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        IntermediateTensors, Logprob,
-                                       SamplerOutput, SequenceGroupMetadata,
-                                       SequenceOutput)
+                                       SequenceGroupMetadata, SequenceOutput)
 from aphrodite.compilation.wrapper import (
     TorchCompileWrapperWithCustomDispacther)
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.task_handler.model_runner_base import (

+ 2 - 1
aphrodite/task_handler/worker.py

@@ -13,7 +13,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
 from aphrodite.common.sequence import (ExecuteModelRequest,
-                                       IntermediateTensors, SamplerOutput,
+                                       IntermediateTensors,
                                        SequenceGroupMetadata,
                                        SequenceGroupMetadataDelta)
 from aphrodite.distributed import (ensure_model_parallel_initialized,
@@ -22,6 +22,7 @@ from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    set_custom_all_reduce)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling import set_random_seed
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.platforms import current_platform
 from aphrodite.prompt_adapter.request import PromptAdapterRequest

+ 2 - 2
aphrodite/task_handler/worker_base.py

@@ -7,13 +7,13 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 import torch
 from loguru import logger
 
-from aphrodite.common.sequence import (ExecuteModelRequest,
-                                       IntermediateTensors, SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
 from aphrodite.common.utils import (enable_trace_function_call_for_thread,
                                     update_environment_variables)
 from aphrodite.distributed import (broadcast_tensor_dict, get_pp_group,
                                    get_tp_group)
 from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.platforms import current_platform
 from aphrodite.task_handler.model_runner_base import (BroadcastableModelInput,
                                                       ModelRunnerBase,

+ 2 - 1
aphrodite/task_handler/xpu_model_runner.py

@@ -13,11 +13,12 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+from aphrodite.common.sequence import (IntermediateTensors,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
 from aphrodite.distributed import get_pp_group
 from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
+from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs, MultiModalRegistry)

Beberapa file tidak ditampilkan karena terlalu banyak file yang berubah dalam diff ini