@@ -1,10 +1,12 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
 import warnings
+from dataclasses import dataclass
 from enum import IntEnum
 from math import inf
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
+import msgspec
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -14,7 +16,8 @@ import aphrodite.common.envs as envs
 from aphrodite.common.sampling_params import SamplingType
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        PromptLogprobs, SampleLogprobs,
-                                       SamplerOutput, SequenceOutput)
+                                       SequenceOutput)
+from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 from aphrodite.triton_utils import HAS_TRITON
 
 if HAS_TRITON:
@@ -27,6 +30,115 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
 
+# Types of temporary data structures used for
+# computing sample_result
+SampleMetadataType = Dict[SamplingType, Tuple[List[int],
+                                              List[SequenceGroupToSample]]]
+MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
+SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
+
+
+# Encapsulates temporary data structures for computing
+# sample_result.
+#
+# * For multi-step scheduling: must be returned
+#   by `Sampler.forward()` and used later to compute the pythonized
+#   sample_result
+#
+# * For single-step scheduling: consumed immediately
+#   inside `Sampler.forward()` to compute pythonized sample_result.
+@dataclass
+class SampleResultArgsType:
+    sample_metadata: SampleMetadataType
+    multinomial_samples: MultinomialSamplesType
+    sample_results_dict: SampleResultsDictType
+    sampling_metadata: SamplingMetadata
+    greedy_samples: Optional[torch.Tensor]
+    beam_search_logprobs: Optional[torch.Tensor]
+
+
+# Union of non-deferred (single-step scheduling)
+# vs deferred (multi-step scheduling)
+# sample result types
+MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
+
+# Abbreviation of the _sample() return type
+SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
+
+
+class SamplerOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+ """For each sequence group, we generate a list of SequenceOutput object,
|
|
|
|
+ each of which contains one possible candidate for the next token.
|
|
|
|
+ This data structure implements methods, so it can be used like a list, but
|
|
|
|
+ also has optional fields for device tensors.
|
|
|
|
+ """
|
|
|
|
+
+    outputs: List[CompletionSequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional[torch.Tensor] = None
+
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None
+
+    # Holds either (1) the pythonized sampler result (single-step scheduling)
+    # or (2) what will be arguments for later deferred pythonization of the
+    # sampler result (multi-step scheduling)
+    deferred_sample_results_args: Optional[SampleResultArgsType] = None
+
+    # On-device tensor containing the sampled token ids.
+    sampled_token_ids: Optional[torch.Tensor] = None
+    # CPU tensor containing the sampled token ids. Used during multi-step to
+    # return the sampled token ids from last rank to AsyncLLMEngine to be
+    # 'broadcasted' to all other PP ranks for next step.
+    sampled_token_ids_cpu: Optional[torch.Tensor] = None
+
+    # Spec decode metrics populated by workers.
+    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
+
+    # Optional last hidden states from the model.
+    hidden_states: Optional[torch.Tensor] = None
+
+    # Optional prefill hidden states from the model
+    # (used for models like EAGLE).
+    prefill_hidden_states: Optional[torch.Tensor] = None
+
+    # Time taken in the forward pass for this step, across all workers.
+    model_forward_time: Optional[float] = None
+
+    # Time taken in the model execute function. This will include model forward,
+    # block/sync across workers, cpu-gpu sync time and sampling time.
+    model_execute_time: Optional[float] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+    def __repr__(self) -> str:
+        """Show the shape of a tensor instead of its values to reduce noise.
+        """
+        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
+                                    else self.sampled_token_probs.shape)
+        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
+                                  self.sampled_token_ids.shape)
+        return (
+            f"SamplerOutput(outputs={self.outputs}, "
+            f"sampled_token_probs={sampled_token_probs_repr}, "
+            f"sampled_token_ids={sampled_token_ids_repr}, "
+            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
 # There isn't a "safe" temperature range for fp16 logits.
 # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
 # that this temperature well-uses the fp16 space after the logits are offset.
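
The msgspec-based SamplerOutput added above keeps the list-like behaviour of the class it replaces (previously imported from aphrodite.common.sequence). A minimal sketch of that interface follows; the import path is an assumption for illustration and is not stated in this patch.

    # Illustrative only; exercises the dunder methods defined in the class above.
    from typing import List

    from aphrodite.common.sequence import CompletionSequenceGroupOutput
    from aphrodite.modeling.layers.sampler import SamplerOutput  # path assumed

    groups: List[CompletionSequenceGroupOutput] = []   # normally filled by the sampler
    out = SamplerOutput(outputs=groups)                # device-tensor fields default to None
    assert len(out) == len(groups)                     # __len__ delegates to .outputs
    assert out == SamplerOutput(outputs=groups)        # __eq__ compares .outputs only
    print(repr(out))                                   # __repr__ prints tensor shapes, not values
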
@@ -135,6 +247,18 @@ class Sampler(nn.Module):
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
         """
+        Single-step scheduling:
+        * Perform GPU-side sampling computation & compute
+          GPU-side logprobs tensor
+        * Pythonize sampling result & logprobs tensor
+        Multi-step scheduling:
+        * Perform GPU-side sampling computation & compute
+          GPU-side logprobs tensor
+        * Defer Pythonization of sampling result & logprobs
+          tensor
+        * Encapsulate arguments required for deferred Pythonization
+          in the :class:`SamplerOutput` structure
+
         Args:
             logits: (num_tokens, vocab_size).
             sampling_metadata: Metadata for sampling.
@@ -425,7 +549,7 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
-        sample_results, maybe_sampled_tokens_tensor = _sample(
+        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
             probs,
             logprobs,
             sampling_metadata,
@@ -435,20 +559,28 @@
         )
 
         if self.include_gpu_probs_tensor:
+            # Since we will defer sampler result Pythonization,
+            # preserve GPU-side tensors in support of later
+            # deferred pythonization of logprobs
             assert maybe_sampled_tokens_tensor is not None
             on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
         else:
+            # Since Pythonization has already happened, don't preserve
+            # GPU-side tensors.
             on_device_tensors = None
 
         # Get the logprobs query results.
         prompt_logprobs = None
         sample_logprobs = None
         if not sampling_metadata.skip_sampler_cpu_output:
-            prompt_logprobs, sample_logprobs = _get_logprobs(
-                logprobs, sampling_metadata, sample_results)
+            # Pythonize logprobs now (GPU -> CPU); do not defer.
+            assert not isinstance(maybe_deferred_sample_results,
+                                  SampleResultArgsType)
+            prompt_logprobs, sample_logprobs = get_logprobs(
+                logprobs, sampling_metadata, maybe_deferred_sample_results)
 
         return _build_sampler_output(
-            sample_results,
+            maybe_deferred_sample_results,
             sampling_metadata,
             prompt_logprobs,
             sample_logprobs,
@@ -1205,6 +1337,57 @@ def _top_k_top_p_multinomial_with_kernels(
     return batch_next_token_ids.view(-1, num_samples)
 
 
+def get_pythonized_sample_results(
+        sample_result_args: SampleResultArgsType) -> SampleResultType:
+    """This function consumes GPU-side sampler results and computes
+    Pythonized CPU-side sampler results (GPU -> CPU sync.)
+    Single-step scheduling: this function is invoked at sampling-time
+    for immediate Pythonization.
+    Multi-step scheduling: Pythonization is deferred until after multiple
+    GPU-side steps have been completed.
+
+    Args:
+      sample_result_args: GPU-side inputs to the Pythonization process
+    Returns:
+      Pythonized sampler results
+    """
+
+    (
+        sample_metadata,
+        sampling_metadata,
+        greedy_samples,
+        multinomial_samples,
+        beam_search_logprobs,
+        sample_results_dict,
+    ) = (
+        sample_result_args.sample_metadata,
+        sample_result_args.sampling_metadata,
+        sample_result_args.greedy_samples,
+        sample_result_args.multinomial_samples,
+        sample_result_args.beam_search_logprobs,
+        sample_result_args.sample_results_dict,
+    )
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, greedy_samples)
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
+            sample_results = _random_sample(seq_groups,
+                                            multinomial_samples[sampling_type])
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups,
+                                                 beam_search_logprobs)
+        sample_results_dict.update(zip(seq_group_id, sample_results))
+
+    return [
+        sample_results_dict.get(i, ([], []))
+        for i in range(len(sampling_metadata.seq_groups))
+    ]
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
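
This helper is the resolution point for the deferred path: under multi-step scheduling, the SamplerOutput produced by forward() carries deferred_sample_results_args, which a later step hands back to get_pythonized_sample_results(). A hedged sketch of that consumer side follows; the helper name resolve_deferred is hypothetical, and only get_pythonized_sample_results and the SamplerOutput fields come from this patch.

    # Hedged sketch of the deferred (multi-step) consumer; not part of this patch.
    from aphrodite.modeling.layers.sampler import (  # path assumed
        SamplerOutput, get_pythonized_sample_results)

    def resolve_deferred(sampler_output: SamplerOutput):
        args = sampler_output.deferred_sample_results_args
        if args is None:
            # Single-step path: results were already Pythonized in forward().
            return None
        # Multi-step path: trigger the GPU -> CPU sync now and build the
        # per-sequence-group (next_token_ids, parent_ids) tuples.
        return get_pythonized_sample_results(args)
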
@@ -1212,7 +1395,18 @@ def _sample_with_torch(
     sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+) -> SampleReturnType:
+    """Torch-oriented _sample() implementation.
+
+    Single-step scheduling:
+    * Perform GPU-side sampling computation
+    * Immediately Pythonize sampling result
+
+    Multi-step scheduling:
+    * Perform GPU-side sampling computation
+    * Defer Pythonization & preserve GPU-side
+      tensors required for Pythonization
+    """
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
@@ -1220,9 +1414,11 @@ def _sample_with_torch(
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
 
-    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
-    sample_metadata = {}
-    multinomial_samples = {}
+    sample_results_dict: SampleResultsDictType = {}
+    sample_metadata: SampleMetadataType = {}
+    multinomial_samples: MultinomialSamplesType = {}
+    greedy_samples: Optional[torch.Tensor] = None
+    beam_search_logprobs: Optional[torch.Tensor] = None
     # Create output tensor for sampled token ids.
     if include_gpu_probs_tensor:
         sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
@@ -1293,32 +1489,29 @@
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
-    # GPU<->CPU sync happens in the loop below.
-    # This also converts the sample output to Python objects.
+    # Encapsulate arguments for computing Pythonized sampler
+    # results, whether deferred or otherwise.
+    maybe_deferred_args = SampleResultArgsType(
+        sampling_metadata=sampling_metadata,
+        sample_metadata=sample_metadata,
+        multinomial_samples=multinomial_samples,
+        greedy_samples=greedy_samples,
+        beam_search_logprobs=beam_search_logprobs,
+        sample_results_dict=sample_results_dict)
+
     if not sampling_metadata.skip_sampler_cpu_output:
-        for sampling_type in SamplingType:
-            if sampling_type not in sample_metadata:
-                continue
-            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
-            if sampling_type == SamplingType.GREEDY:
-                sample_results = _greedy_sample(seq_groups, greedy_samples)
-            elif sampling_type in (SamplingType.RANDOM,
-                                   SamplingType.RANDOM_SEED):
-                sample_results = _random_sample(
-                    seq_groups, multinomial_samples[sampling_type])
-            elif sampling_type == SamplingType.BEAM:
-                sample_results = _beam_search_sample(seq_groups,
-                                                     beam_search_logprobs)
-            sample_results_dict.update(zip(seq_group_id, sample_results))
-
-        sample_results = [
-            sample_results_dict.get(i, ([], []))
-            for i in range(len(sampling_metadata.seq_groups))
-        ]
+        # GPU<->CPU sync happens here.
+        # This also converts the sampler output to a Python object.
+        # Return Pythonized sampler result & sampled token ids
+        return get_pythonized_sample_results(
+            maybe_deferred_args), sampled_token_ids_tensor
     else:
-        sample_results = []
-
-    return sample_results, sampled_token_ids_tensor
+        # Defer sampler result Pythonization; return deferred
+        # Pythonization args & sampled token ids
+        return (
+            maybe_deferred_args,
+            sampled_token_ids_tensor,
+        )
 
 
 def _sample_with_triton_kernel(
@@ -1396,10 +1589,13 @@ def _sample_with_triton_kernel(
 
 
 def _sample(
-    probs: torch.Tensor, logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
+) -> SampleReturnType:
     """
     Args:
         probs: (num_query_tokens_in_batch, num_vocab)
@@ -1441,7 +1637,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     return (x > vals[:, None]).long().sum(1).add_(1)
 
 
-def _get_logprobs(
+def get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sample_results: List[Tuple[List[int], List[int]]],
@@ -1755,7 +1951,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 
 
 def _build_sampler_output(
-    sample_results: SampleResultType,
+    maybe_deferred_sample_results: MaybeDeferredSampleResultType,
     sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
     sample_logprobs: Optional[List[SampleLogprobs]],
@@ -1771,14 +1967,21 @@ def _build_sampler_output(
         speculative decoding rejection sampling.
     """
     sampler_output: List[CompletionSequenceGroupOutput] = []
-    if not skip_sampler_cpu_output:
+
+    if skip_sampler_cpu_output:
+        assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
+        deferred_sample_results_args = maybe_deferred_sample_results
+    else:
         assert prompt_logprobs is not None
         assert sample_logprobs is not None
+        assert not isinstance(maybe_deferred_sample_results,
+                              SampleResultArgsType)
+        deferred_sample_results_args = None
 
         for (seq_group, sample_result, group_prompt_logprobs,
              group_sample_logprobs) in zip(sampling_metadata.seq_groups,
-                                           sample_results, prompt_logprobs,
-                                           sample_logprobs):
+                                           maybe_deferred_sample_results,
+                                           prompt_logprobs, sample_logprobs):
             seq_ids = seq_group.seq_ids
             next_token_ids, parent_ids = sample_result
             seq_outputs: List[SequenceOutput] = []
@@ -1802,7 +2005,7 @@ def _build_sampler_output(
         sampled_token_probs=sampled_token_probs,
         sampled_token_ids=sampled_token_ids,
         logprobs=logprobs_tensor,
-    )
+        deferred_sample_results_args=deferred_sample_results_args)
 
 
 def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]: