@@ -1,225 +1,155 @@
 """A layer that samples the next tokens from the model's outputs."""
+import itertools
+import math
 from typing import Dict, List, Tuple, Optional

 import torch
 import torch.nn as nn
-import math

 from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   OutputMetadata,
                                                   SamplingTensors)
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_gather)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                                        SamplerOutput, SequenceData,
                                        SequenceGroupOutput, SequenceOutput)
 from aphrodite.modeling.layers.ops.sample import sample as sample_triton
-from aphrodite.common.utils import is_neuron


 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
-
     This layer does the following:
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
+    3. Apply all the different sampler functions in the specified order.
+    4. Sample the next tokens.
     Here, each sequence group within the batch can have different sampling
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+    The structure of the logits tensor is coupled with the seq_groups in
+    sampling_metadata. Typically, each sequence in each seq_group has one row
+    in logits for the next token to be sampled; however, for a seq_group with
+    a prompt request with the prompt_logprobs sampling parameter, there are
+    rows in logits for each token in the input prompt.
     """

-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
+    def __init__(self):
         super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generates outputs as logits directly
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
-        return _perform_sampling(logits, sampling_metadata)
-
-
-# FIXME: This is a hack for the missing GPU blocks. This should be removed
-# once a proper fix is implemented.
-class QuantSampler(nn.Module):
-    """Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
+        # Whether or not the SamplerOutput should have on-device tensors
+        # containing the sampled token ids and probabilities. This is used by
+        # speculative decoding.
+        self.include_gpu_probs_tensor = False

     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        logits = _prune_hidden_states(logits, sampling_metadata)
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.vocab_size]
-
-        return _perform_sampling(logits, sampling_metadata)
-
-
-def _perform_sampling(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-    # Only perform sampling in the driver worker.
-    # Note: `_get_logits` is still distributed across TP workers because
-    # the `embedding` weight is distributed across TP workers.
-    # TODO: Change the get_logits part to a separate stage.
-    if not sampling_metadata.perform_sampling:
-        return None
-
-    assert logits is not None
-    _, vocab_size = logits.shape
-
-    output_metadata = OutputMetadata()
-
-    # Apply logits processors (if any)
-    logits = _apply_logits_processors(logits, sampling_metadata)
-
-    # Prepare sampling tensors with pinned memory to avoid blocking.
-    sampling_tensors = SamplingTensors.from_sampling_metadata(
-        sampling_metadata, vocab_size, logits.device, logits.dtype)
-
-    if sampling_tensors.do_penalties:
-        logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
-                                  sampling_tensors.output_tokens,
-                                  sampling_tensors.pres_penalties,
-                                  sampling_tensors.freq_penalties,
-                                  sampling_tensors.rep_penalties)
-
-    if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
-        logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                    sampling_tensors.dynatemp_mins,
-                                    sampling_tensors.dynatemp_maxs,
-                                    sampling_tensors.dynatemp_exps)
-
-    if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
-            or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
-        logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
-                                      sampling_tensors.top_ks,
-                                      sampling_tensors.top_as,
-                                      sampling_tensors.min_ps)
-
-    if sampling_tensors.do_tfss:
-        logits = _apply_tfs(logits, sampling_tensors.tfss)
-    if sampling_tensors.do_eta_cutoffs:
-        logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
-    if sampling_tensors.do_epsilon_cutoffs:
-        logits = _apply_epsilon_cutoff(logits,
-                                       sampling_tensors.epsilon_cutoffs)
-    if sampling_tensors.do_typical_ps:
-        logits = _apply_typical_sampling(logits, sampling_tensors.typical_ps)
-
-    if sampling_tensors.do_quadratic:
-        logits = _apply_quadratic_sampling(logits,
-                                           sampling_tensors.smoothing_indices,
-                                           sampling_tensors.smoothing_factors,
-                                           sampling_tensors.smoothing_curves)
-
-    banned_tokens = _get_custom_token_bans(sampling_metadata)
-    assert len(banned_tokens) == logits.shape[0]
-    logits = _apply_token_bans(logits, banned_tokens)
-    if sampling_tensors.do_mirostat:
-        logits = _apply_mirostat_v2(logits, sampling_tensors)
-
-    # We use float32 for probabilities and log probabilities.
-    # Compute the probabilities.
-    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-    # Compute the log probabilities.
-    # Use log_softmax to ensure numerical stability.
-    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-    # Sample the next tokens.
-    sample_results = _sample(probs, logprobs, sampling_metadata,
-                             sampling_tensors)
-
-    if sampling_tensors.do_mirostat:
-        _mirostat_store_args(logits, sampling_tensors, sample_results,
-                             sampling_metadata, output_metadata)
-    # Get the logprobs query results.
-    prompt_logprobs, sample_logprobs = _get_logprobs(logprobs,
-                                                     sampling_metadata,
-                                                     sample_results)
-    return _build_sampler_output(sample_results, sampling_metadata,
-                                 prompt_logprobs, sample_logprobs,
-                                 output_metadata)
-
-
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
+        assert logits is not None
+        _, vocab_size = logits.shape
+        output_metadata = OutputMetadata()
+        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
+        # have not been generated yet
+        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
+
+        # Prepare sampling tensors with pinned memory to avoid blocking.
+        sampling_tensors = SamplingTensors.from_sampling_metadata(
+            sampling_metadata, vocab_size, logits.device, logits.dtype)
+
+        if sampling_tensors.do_penalties:
+            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                      sampling_tensors.output_tokens,
+                                      sampling_tensors.pres_penalties,
+                                      sampling_tensors.freq_penalties,
+                                      sampling_tensors.rep_penalties)
+
+        if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
+            logits = _apply_temperature(logits, sampling_tensors.temperatures,
+                                        sampling_tensors.dynatemp_mins,
+                                        sampling_tensors.dynatemp_maxs,
+                                        sampling_tensors.dynatemp_exps)
+
+        if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
+                or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
+            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
+                                          sampling_tensors.top_ks,
+                                          sampling_tensors.top_as,
+                                          sampling_tensors.min_ps)
+        if sampling_tensors.do_tfss:
+            logits = _apply_tfs(logits, sampling_tensors.tfss)
+        if sampling_tensors.do_eta_cutoffs:
+            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
+        if sampling_tensors.do_epsilon_cutoffs:
+            logits = _apply_epsilon_cutoff(logits,
+                                           sampling_tensors.epsilon_cutoffs)
+        if sampling_tensors.do_typical_ps:
+            logits = _apply_typical_sampling(logits,
+                                             sampling_tensors.typical_ps)
+
+        if sampling_tensors.do_quadratic:
+            logits = _apply_quadratic_sampling(
+                logits, sampling_tensors.smoothing_indices,
+                sampling_tensors.smoothing_factors,
+                sampling_tensors.smoothing_curves)
+
+        banned_tokens = _get_custom_token_bans(sampling_metadata)
+        assert len(banned_tokens) == logits.shape[0]
+        logits = _apply_token_bans(logits, banned_tokens)
+        if sampling_tensors.do_mirostat:
+            logits = _apply_mirostat_v2(logits, sampling_tensors)
+
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities.
+        # Use log_softmax to ensure numerical stability.
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+        # Sample the next tokens.
+        # sample_results = _sample(probs, logprobs, sampling_metadata,
+        #                          sampling_tensors)
+        sample_results, maybe_sampled_tokens_tensor = _sample(
+            probs,
+            logprobs,
+            sampling_metadata,
+            sampling_tensors,
+            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
+            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
+        )
+
+        if self.include_gpu_probs_tensor:
+            assert maybe_sampled_tokens_tensor is not None
+            sampled_tokens_tensor = maybe_sampled_tokens_tensor
+            on_device_tensors = (probs, sampled_tokens_tensor)
+        else:
+            on_device_tensors = None
+
+        if sampling_tensors.do_mirostat:
+            _mirostat_store_args(logits, sampling_tensors, sample_results,
+                                 sampling_metadata, output_metadata)
+        # Get the logprobs query results.
+        prompt_logprobs, sample_logprobs = _get_logprobs(
+            logprobs, sampling_metadata, sample_results)
+        # return _build_sampler_output(sample_results, sampling_metadata,
+        #                              prompt_logprobs, sample_logprobs,
+        #                              output_metadata)
+        return _build_sampler_output(sample_results, sampling_metadata,
+                                     prompt_logprobs, sample_logprobs,
+                                     output_metadata, on_device_tensors)
+
+    @property
+    def _should_modify_greedy_probs_inplace(self) -> bool:
+        """Whether or not the sampler should modify the probability
+        distribution of greedily-sampled tokens such that multinomial sampling
+        would sample the greedily-sampled token.
+        In other words, if True then we set the probability of the greedily-
+        sampled token to 1.
+        This is used by speculative decoding, which requires that the sampling
+        method be encoded into the probability distribution.
+        """
+        # Modify greedy probs if include_gpu_probs_tensor is set.
+        return self.include_gpu_probs_tensor


 def _get_bin_counts_and_mask(
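Editorial note: the reworked Sampler docstring above says that a prompt seq_group which requested prompt_logprobs contributes one row of `logits` per prompt token, while a regular decoding sequence contributes a single row. A minimal sketch of that row accounting, with invented group sizes (illustration only, not part of the patch):

```python
# Row accounting implied by the new Sampler docstring; all sizes are made up.
prompt_len = 4   # hypothetical prompt seq_group that requested prompt_logprobs
decode_seqs = 2  # hypothetical decode seq_group with two running sequences

# The prompt group gets one row per prompt token (the last of which is the row
# actually sampled from); the decode group gets one row per sequence.
num_logits_rows = prompt_len + decode_seqs
assert num_logits_rows == 6
```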
@@ -255,35 +185,6 @@ def _get_custom_token_bans(
     return banned_tokens


-def _apply_logits_processors(
-    logits: torch.Tensor,
-    metadata: SamplingMetadata,
-) -> torch.Tensor:
-    assert metadata.seq_groups is not None
-    assert metadata.prompt_lens is not None
-    assert metadata.seq_data is not None
-    seq_offset = 0
-    for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
-        seq_size = len(seq_ids)
-        output_tokens = []
-        if (i < metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_seqs = metadata.prompt_lens[i] - 1
-            seq_size += prompt_seqs
-            output_tokens.extend([[]] * prompt_seqs)
-        seq_end = seq_offset + seq_size
-
-        if sampling_params.logits_processors:
-            output_tokens.extend(metadata.seq_data[sid].output_token_ids
-                                 for sid in seq_ids)
-            for proc in sampling_params.logits_processors:
-                proc(logits[seq_offset:seq_end], output_tokens)
-
-        seq_offset = seq_end
-
-    return logits
-
-
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
@@ -316,6 +217,44 @@ def _apply_token_bans(logits: torch.Tensor,
     return logits


+def _apply_min_tokens_penalty(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    assert sampling_metadata.seq_groups is not None
+    assert sampling_metadata.seq_data is not None
+    # list of indices in logits that will be set to -inf
+    logits_to_penalize = []
+    start_idx = 0
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        min_tokens = sampling_params.min_tokens
+        if min_tokens > 0:
+            seqs_to_penalize = []
+            for i, seq_id in enumerate(seq_ids):
+                seq_data = sampling_metadata.seq_data[seq_id]
+                if len(seq_data.output_token_ids) < min_tokens:
+                    seqs_to_penalize.append(i)
+
+            if seqs_to_penalize:
+                # convert to the index into logits
+                seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
+                # use set() to remove any duplicates
+                token_ids_to_penalize = set(sampling_params.stop_token_ids +
+                                            [sampling_params.eos_token_id])
+                # itertools.product pairs each seq index with every token id
+                logits_to_penalize.extend(
+                    itertools.product(seqs_to_penalize, token_ids_to_penalize))
+
+        start_idx += len(seq_ids)
+
+    if logits_to_penalize:
+        # use zip and * to group indices along each dimension
+        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
+        logits[tuple(zip(*logits_to_penalize))] = -float("inf")
+
+    return logits
+
+
 def _apply_alphabet_soup(
     logits: torch.Tensor,
     p: torch.Tensor,
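The `_apply_min_tokens_penalty` helper added above relies on a compact indexing idiom: `itertools.product` builds (row, token_id) pairs and `zip(*...)` regroups them into the per-dimension index tuples that torch advanced indexing expects. A standalone sketch of just that idiom (not part of the patch):

```python
import itertools

import torch

logits = torch.zeros(3, 5)
pairs = list(itertools.product([0, 2], [1, 4]))  # [(0, 1), (0, 4), (2, 1), (2, 4)]
# zip(*pairs) regroups into rows (0, 0, 2, 2) and columns (1, 4, 1, 4).
logits[tuple(zip(*pairs))] = -float("inf")
assert logits[0, 1].item() == -float("inf") and logits[2, 4].item() == -float("inf")
```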
@@ -645,6 +584,7 @@ def _multinomial(
     if seq_groups is None:
         q.exponential_()
     else:
+        assert generators is not None
         sample_idx = 0
         for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
@@ -657,7 +597,9 @@ def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-) -> List[Tuple[List[int], List[int]]]:
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
+) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
     """Returns list of (selected_tokens, parent_seq_ids) tuples
     corresponding to sampling_metadata.seq_groups."""
     assert sampling_metadata.seq_groups is not None
@@ -674,6 +616,15 @@ def _sample_with_torch(
     sample_metadata = {}
     multinomial_samples = {}

+    # Create output tensor for sampled token ids.
+    if include_gpu_probs_tensor:
+        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
+                                               1,
+                                               dtype=torch.long,
+                                               device=logprobs.device)
+    else:
+        sampled_token_ids_tensor = None
+
     # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type, sample_indices in categorized_sample_indices.items():
@@ -685,9 +636,23 @@ def _sample_with_torch(
         is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
+        long_sample_indices = sample_indices.long()
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                           dim=-1)
+
+            if include_gpu_probs_tensor:
+                # Store sampled tokens in output tensor.
+                sampled_token_ids_tensor[
+                    long_sample_indices] = greedy_samples.unsqueeze(-1)
+
+            if modify_greedy_probs:
+                # If required, modify the probabilities such that sampling from
+                # the modified distribution would always sample the argmax
+                # token id.
+                _modify_greedy_probs_inplace(logprobs, probs,
+                                             long_sample_indices,
+                                             greedy_samples)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of_in_batch = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
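For reference, the on-device bookkeeping added in this hunk boils down to an index assignment into an (N, 1) long tensor. A small standalone sketch (not part of the patch; values invented):

```python
import torch

sampled_token_ids_tensor = torch.zeros(4, 1, dtype=torch.long)
long_sample_indices = torch.tensor([0, 2])  # rows handled by the greedy path
greedy_samples = torch.tensor([7, 11])      # argmax token ids for those rows
sampled_token_ids_tensor[long_sample_indices] = greedy_samples.unsqueeze(-1)
assert sampled_token_ids_tensor.squeeze(-1).tolist() == [7, 0, 11, 0]
```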
@@ -700,14 +665,20 @@ def _sample_with_torch(
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices.long()], max_best_of_in_batch,
+                probs[long_sample_indices], max_best_of_in_batch,
                 **seeded_args)
+
+            if include_gpu_probs_tensor:
+                # Store sampled tokens in output tensor.
+                sampled_token_ids_tensor[
+                    long_sample_indices] = multinomial_samples[sampling_type]
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")

     # GPU<->CPU sync happens in the loop below.
+    # This also converts the sample output to Python objects.
     for sampling_type, metadata in sample_metadata.items():
         seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
@@ -726,7 +697,7 @@ def _sample_with_torch(
         sample_results_dict[i]
         for i in range(len(sampling_metadata.seq_groups))
     ]
-    return sample_results
+    return sample_results, sampled_token_ids_tensor


 def _sample_with_triton_kernel(
@@ -736,6 +707,7 @@ def _sample_with_triton_kernel(
     sampling_tensors: SamplingTensors,
 ) -> List[Tuple[List[int], List[int]]]:
     assert sampling_metadata.seq_groups is not None
+    assert sampling_metadata.categorized_sample_indices is not None
     assert sampling_metadata.seq_data is not None
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
@@ -750,7 +722,6 @@ def _sample_with_triton_kernel(

     # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
-    assert categorized_sample_indices is not None
     for sampling_type, sample_indices in categorized_sample_indices.items():
         sampled_token_indices = sample_indices[:, 1]
         sample_indices = sample_indices[:, 0]
@@ -812,18 +783,40 @@ def _sample_with_triton_kernel(


 def _sample(
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
-    return _sample_with_torch(probs, logprobs, sampling_metadata)
+    probs: torch.Tensor, logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
+) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+    return _sample_with_torch(
+        probs,
+        logprobs,
+        sampling_metadata,
+        include_gpu_probs_tensor=include_gpu_probs_tensor,
+        modify_greedy_probs=modify_greedy_probs,
+    )

     # TODO: Enable once Triton kernel & associated code is faster.
     # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
     #                                   sampling_tensors)


+def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+    """
+    This function calculates the ranks of the chosen tokens in a logprob tensor.
+    Args:
+        x (torch.Tensor): 2D logprob tensor of shape (N, M)
+                        where N is the no. of tokens and M is the vocab dim.
+        indices (torch.Tensor): List of chosen token indices.
+    Returns:
+        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
+                    Each element in the returned tensor represents the rank
+                    of the chosen token in the input logprob tensor.
+    """
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    return (x > vals[:, None]).long().sum(1).add_(1)
+
+
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
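A quick standalone check of the rank definition used by the new `_get_ranks` helper: the rank of a chosen token is 1 plus the number of entries in its row with a strictly larger value, so the argmax token has rank 1. Not part of the patch:

```python
import torch

x = torch.log(torch.tensor([[0.1, 0.7, 0.2],
                            [0.5, 0.3, 0.2]]))
indices = torch.tensor([2, 0])  # chosen token per row
vals = x[torch.arange(len(x)), indices]
ranks = (x > vals[:, None]).long().sum(1).add_(1)
assert ranks.tolist() == [2, 1]
```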
@@ -836,7 +829,8 @@ def _get_logprobs(
     # Prepare query indices
     batched_logprobs_query_seq_indices: List[int] = []
     batched_logprobs_query_token_indices: List[int] = []
-    largest_num_logprobs = 0
+    # at least get one logprob for each token
+    largest_num_logprobs = 1
     sample_idx = 0
     for i, (seq_group, sample_result) in enumerate(
             zip(sampling_metadata.seq_groups, sample_results)):
@@ -864,12 +858,18 @@ def _get_logprobs(
             sample_idx += num_parent_seqs
     assert sample_idx == logprobs.size(0)

+    batched_logprobs_query_seq_indices_gpu = torch.tensor(
+        batched_logprobs_query_seq_indices, device=logprobs.device)
+    batched_logprobs_query_token_indices_gpu = torch.tensor(
+        batched_logprobs_query_token_indices, device=logprobs.device)
     # Batched query for logprobs of selected token
     batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices,
-        batched_logprobs_query_token_indices
+        batched_logprobs_query_seq_indices_gpu,
+        batched_logprobs_query_token_indices_gpu
     ]]
-
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices_gpu],
+        batched_logprobs_query_token_indices_gpu)
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
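The batched query above uses a pair of index tensors to pick one scalar per (sequence row, token id). A standalone sketch of that gather pattern (not part of the patch; shapes invented):

```python
import torch

logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)
seq_indices = torch.tensor([0, 1, 1, 3])
token_indices = torch.tensor([5, 2, 7, 0])
# Equivalent to logprobs[seq_indices, token_indices]: one scalar per index pair.
selected = logprobs[[seq_indices, token_indices]]
assert selected.shape == (4,)
```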
@@ -882,6 +882,8 @@ def _get_logprobs(

     batched_logprobs_query_result = batched_logprobs_query_result.cpu()

+    batched_ranks_query_result = batched_ranks_query_result.cpu()
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
@@ -891,7 +893,6 @@ def _get_logprobs(
             zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
-
         # Prompt logprobs
         if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
@@ -902,22 +903,26 @@ def _get_logprobs(
             for token_id in prompt_tokens[1:]:
                 prompt_logprobs_dict = {
                     token_id:
-                    batched_logprobs_query_result[query_result_idx].item()
+                    (batched_logprobs_query_result[query_result_idx].item(),
+                     batched_ranks_query_result[query_result_idx].item())
                 }
                 if num_logprobs > 0:
                     prompt_logprobs_dict.update(
-                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
-                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                        zip(
+                            top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            zip(
+                                top_logprobs[
+                                    sample_idx, :num_logprobs].tolist(),
+                                range(1, num_logprobs + 1))))
                 group_prompt_logprobs.append({
-                    token_id: Logprob(logprob)
-                    for token_id, logprob in prompt_logprobs_dict.items()
+                    token_id: Logprob(*logprob_rank)
+                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                 })
                 sample_idx += 1
                 query_result_idx += 1
             result_prompt_logprobs.append(group_prompt_logprobs)
         else:
             result_prompt_logprobs.append(None)
-
         # Sample logprobs
         num_logprobs = sampling_params.logprobs
         if num_logprobs is None:
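The nested `zip` introduced above pairs each top-k token id with a (logprob, rank) tuple, ranks counted from 1. A plain-Python sketch of the resulting mapping (not part of the patch; numbers invented):

```python
top_token_ids = [42, 7, 99]
top_logprobs = [-0.1, -1.3, -2.0]
pairs = dict(
    zip(top_token_ids, zip(top_logprobs, range(1, len(top_logprobs) + 1))))
assert pairs == {42: (-0.1, 1), 7: (-1.3, 2), 99: (-2.0, 3)}
```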
@@ -926,33 +931,89 @@ def _get_logprobs(
         for next_token_id, parent_id in zip(next_token_ids, parent_ids):
             sample_logprobs_dict = {
                 next_token_id:
-                batched_logprobs_query_result[query_result_idx].item()
+                (batched_logprobs_query_result[query_result_idx].item(),
+                 batched_ranks_query_result[query_result_idx].item())
             }
             query_result_idx += 1
-            if num_logprobs > 0:
+            if num_logprobs >= 0:
                 sample_logprobs_dict.update(
                     zip(
                         top_token_ids[sample_idx +
                                       parent_id, :num_logprobs].tolist(),
-                        top_logprobs[sample_idx +
-                                     parent_id, :num_logprobs].tolist()))
+                        zip(
+                            top_logprobs[sample_idx +
+                                         parent_id, :num_logprobs].tolist(),
+                            range(1, num_logprobs + 1))))
             group_sample_logprobs.append({
-                token_id: Logprob(logprob)
-                for token_id, logprob in sample_logprobs_dict.items()
+                token_id: Logprob(*logprob_rank)
+                for token_id, logprob_rank in sample_logprobs_dict.items()
             })
         result_sample_logprobs.append(group_sample_logprobs)
         sample_idx += len(seq_ids)
-
     return result_prompt_logprobs, result_sample_logprobs


+def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
+                                 sample_indices: torch.Tensor,
+                                 greedy_samples: torch.Tensor) -> None:
+    """Modify the probability distributions of the greedily-sampled tokens such
+    that each sampled token has a "probability" of 1.0. This is required by
+    speculative decoding, which depends on the sampling method being encoded
+    within the probability distribution for correctness.
+    # Why do we only need to do this for greedy sampling?
+    Aphrodite's sampler performs the following steps for greedy or multinomial
+    (random) sampling:
+    1. Get logits from model.
+    2. Modify logits according to per-sequence sampling parameters.
+        - Multiply by temperature, top-k and top-p masking, penalize tokens
+            according to their frequency, etc.
+    3. Sample a token.
+        - Random sampling simply samples from the modified probability
+        distribution.
+        - Greedy sampling performs `argmax` to obtain the token with the
+        highest likelihood.
+
+    Ignoring greedy sampling for a moment, we find that the computed probability
+    distribution has the following property: we can sample from it independently
+    and find that the token sampled by the Sampler has a frequency corresponding
+    to how often we see it in our sampling. In other words, for tokens sampled
+    with Aphrodite's random SamplingType, the computed probability distribution
+    encodes the sampling methodology completely.
+    Greedy sampling does not normally have this property. Aphrodite modifies
+    logits according to sampling params, then performs `argmax`, then returns
+    the sampled token and the computed probability distribution. If we sample
+    from the distribution, we'll find the likelihood of the greedily-sampled
+    token is not always 1.0.
+    Since lossless speculative decoding requires that the sampling methodology
+    be encoded within the probability distribution, we are motivated to modify
+    the probability distribution such that the sampled token has probability 1
+    when speculative decoding is used.
+    NOTE: Alternatively, we could use an extremely low temperature to achieve
+    greedy sampling using multinomial computation and unite the codepaths. This
+    has implications on the overall design of the sampler, e.g. how to record
+    accurate logprobs for the user, so this improvement is deferred to later.
+    """
+    logprobs[sample_indices, :] = -float('inf')
+    logprobs[sample_indices, greedy_samples] = 0.0
+    probs[sample_indices, :] = 0
+    probs[sample_indices, greedy_samples] = 1.0
+
+
 def _build_sampler_output(
     sample_results: List[Tuple[List[int], List[int]]],
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],
     output_metadata: OutputMetadata,
+    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
 ) -> SamplerOutput:
+    """Construct Python objects with the output of sampling.
+    Args:
+        on_device_tensors: Tuple containing on-device tensors with the
+            probabilities used in sampling and the sampled token ids. This
+            allows post-processing without copies to CPU/serialization, e.g. in
+            speculative decoding rejection sampling.
+    """
     assert sampling_metadata.seq_groups is not None
     sampler_output = []
     for (seq_group, sample_result, group_prompt_logprobs,
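Numeric sketch of what `_modify_greedy_probs_inplace` does to the rows it touches: the greedy token ends up with probability 1.0 (logprob 0.0) and every other token with 0 (-inf), so multinomial sampling from the stored distribution reproduces the greedy choice. Not part of the patch:

```python
import torch

probs = torch.tensor([[0.6, 0.3, 0.1]])
logprobs = torch.log(probs)
sample_indices = torch.tensor([0])
greedy_samples = probs.argmax(dim=-1)  # tensor([0])
logprobs[sample_indices, :] = -float("inf")
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
assert torch.multinomial(probs[0], 1).item() == 0  # always the argmax token
```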
@@ -969,7 +1030,17 @@ def _build_sampler_output(

         sampler_output.append(
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
-    return SamplerOutput(outputs=sampler_output)
+    # If not specified, store None values in SamplerOutput.
+    if on_device_tensors is not None:
+        sampled_token_probs, sampled_token_ids = on_device_tensors
+    else:
+        sampled_token_probs, sampled_token_ids = (None, None)
+
+    return SamplerOutput(
+        outputs=sampler_output,
+        sampled_token_probs=sampled_token_probs,
+        sampled_token_ids=sampled_token_ids,
+    )


 def _apply_mirostat_v2(logits: torch.Tensor,
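Finally, a shape sketch of the optional on-device outputs threaded through to `SamplerOutput`: `sampled_token_probs` is the float32 distribution that was actually sampled from and `sampled_token_ids` is a (num_rows, 1) long tensor, matching what `_sample_with_torch` fills in. Sizes here are invented; this is illustration only, not part of the patch:

```python
import torch

num_rows, vocab_size = 6, 32000  # made-up sizes
sampled_token_probs = torch.softmax(
    torch.randn(num_rows, vocab_size), dim=-1, dtype=torch.float)
sampled_token_ids = torch.multinomial(sampled_token_probs, num_samples=1)
assert sampled_token_ids.shape == (num_rows, 1)
on_device_tensors = (sampled_token_probs, sampled_token_ids)
```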