
fix: sampler indexing issues in distributed environments (#546)

* attempt 1

* re-add samplers
AlpinDale 8 months ago
parent commit 9ce319b03c

+ 9 - 2
aphrodite/common/sampling_params.py

@@ -5,7 +5,8 @@ from functools import cached_property
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
-from pydantic import conint
+from pydantic import Field
+from typing_extensions import Annotated

 _SAMPLING_EPS = 1e-5

@@ -170,7 +171,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessorFunc]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -220,6 +221,12 @@ class SamplingParams:
         self.logits_processors = logits_processors or []
         self.include_stop_str_in_output = include_stop_str_in_output
         self.truncate_prompt_tokens = truncate_prompt_tokens
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0

         self.default_values = {
             "n": 1,

+ 77 - 78
aphrodite/modeling/layers/sampler.py

@@ -17,14 +17,18 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,

 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
+
     This layer does the following:
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply all the different sampler functions in the specified order.
-    4. Sample the next tokens.
+    3. Apply presence, frequency and repetition penalties.
+    4. Apply temperature scaling.
+    5. Apply top-p and top-k truncation.
+    6. Sample the next tokens.
     Here, each sequence group within the batch can have different sampling
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+
     The structure of the logits tensor is coupled with the seq_groups in
     sampling_metadata. Typically, each sequence in each seq_group has one row in
     logits for the next token to be sampled; however, for a seq_group with a
@@ -52,17 +56,16 @@ class Sampler(nn.Module):
         """
         """
         assert logits is not None
         assert logits is not None
         _, vocab_size = logits.shape
         _, vocab_size = logits.shape
-        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
-        # have not been generated yet
+
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
 
         # Prepare sampling tensors with pinned memory to avoid blocking.
         # Prepare sampling tensors with pinned memory to avoid blocking.
-        (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
-         do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-         do_typical_ps,
-         do_quadratic) = (SamplingTensors.from_sampling_metadata(
-             sampling_metadata, vocab_size, logits.device, logits.dtype))
+        (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
+         do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
+         do_quadratic) = SamplingTensors.from_sampling_metadata(
+             sampling_metadata, vocab_size, logits.device, logits.dtype)
 
 
+        # Apply presence and frequency penalties.
         if do_penalties:
         if do_penalties:
             logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
             logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                       sampling_tensors.output_tokens,
                                       sampling_tensors.output_tokens,
@@ -70,18 +73,30 @@ class Sampler(nn.Module):
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)

-        if (do_topks or do_topps or do_topas or do_minps):
-            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
-                                          sampling_tensors.top_ks,
-                                          sampling_tensors.top_as,
-                                          sampling_tensors.min_ps)
+        # Apply temperature scaling.
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
+
+        if do_top_p_top_k:
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)
+
+        if do_top_as:
+            logits = _apply_top_a(logits, sampling_tensors.top_as)
+
+        if do_min_p:
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
+
         if do_tfss:
             logits = _apply_tfs(logits, sampling_tensors.tfss)
+
         if do_eta_cutoffs:
             logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
+
         if do_epsilon_cutoffs:
             logits = _apply_epsilon_cutoff(logits,
                                            sampling_tensors.epsilon_cutoffs)
+
         if do_typical_ps:
             logits = _apply_typical_sampling(logits,
                                              sampling_tensors.typical_ps)
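Temperature scaling now runs inline, before the truncation filters, via an in-place division. A toy check of the broadcasting involved (values made up; as the later `sampling_metadata.py` hunk shows, zero temperatures are rewritten to 1.0 upstream, so the division is safe):

```python
import torch

# Illustrative only: per-row temperature division, as in the new code path.
logits = torch.tensor([[2.0, 4.0], [3.0, 6.0]])
temperatures = torch.tensor([2.0, 1.0])
logits.div_(temperatures.unsqueeze_(dim=1))  # unsqueeze to [batch, 1] to broadcast
print(logits)  # tensor([[1., 2.], [3., 6.]])
```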
@@ -91,15 +106,7 @@ class Sampler(nn.Module):
                 logits, sampling_tensors.smoothing_factors,
                 sampling_tensors.smoothing_curves)

-        if do_temperatures:
-            logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                        # sampling_tensors.dynatemp_mins,
-                                        # sampling_tensors.dynatemp_maxs,
-                                        # sampling_tensors.dynatemp_exps
-                                        )
-
         banned_tokens = _get_custom_token_bans(sampling_metadata)
-        # assert len(banned_tokens) == logits.shape[0]
         logits = _apply_token_bans(logits, banned_tokens)

         # We use float32 for probabilities and log probabilities.
@@ -117,12 +124,14 @@ class Sampler(nn.Module):
             include_gpu_probs_tensor=self.include_gpu_probs_tensor,
             modify_greedy_probs=self._should_modify_greedy_probs_inplace,
         )
+
         if self.include_gpu_probs_tensor:
             assert maybe_sampled_tokens_tensor is not None
             sampled_tokens_tensor = maybe_sampled_tokens_tensor
             on_device_tensors = (probs, sampled_tokens_tensor)
         else:
             on_device_tensors = None
+
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
@@ -137,8 +146,10 @@ class Sampler(nn.Module):
         """Whether or not the sampler should modify the probability distribution
         """Whether or not the sampler should modify the probability distribution
         of greedily-sampled tokens such that multinomial sampling would sample
         of greedily-sampled tokens such that multinomial sampling would sample
         the greedily-sampled token.
         the greedily-sampled token.
+
         In other words, if True then we set the probability of the greedily-
         In other words, if True then we set the probability of the greedily-
         sampled token to 1.
         sampled token to 1.
+
         This is used by speculative decoding, which requires that the sampling
         This is used by speculative decoding, which requires that the sampling
         method be encoded into the probability distribution.
         method be encoded into the probability distribution.
         """
         """
@@ -258,38 +269,27 @@ def _apply_min_tokens_penalty(
     return logits


-def _apply_alphabet_soup(
+def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
-    a: torch.Tensor,
-    m: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
-
-    # Apply top-p, min-p and top-a.
-    probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    min_p_thresholds = probs_sort[:, 0] * m
-    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
-    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
-    mask = (probs_sort < threshold.unsqueeze(1)
-            )  # Cull logits below the top-a threshold
-    mask.logical_or_(
-        probs_sum >
-        p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
-    mask[:, 0] = False  # Guarantee at least one token is pickable
-    logits_sort[mask] = -float("inf")
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

     # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))

-    # Final mask.
-    mask = (mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
+    # Apply top-p.
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))

     # Re-sort the probabilities.
     src = torch.arange(logits_idx.shape[-1],
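A standalone check of the rewritten masking (toy values, not from the repository): with an ascending sort, the k-th largest logit sits k positions from the end, and the cumulative-sum mask `probs_sum <= 1 - p` marks exactly the low-probability tail:

```python
import torch

logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
p, k = torch.tensor([0.9]), torch.tensor([2])

logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

# Top-k: the k-th largest value is k positions from the end of the
# ascending sort; everything strictly below it is masked out.
kth = logits_sort.gather(
    1, (logits_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1))
logits_sort.masked_fill_(logits_sort < kth, -float("inf"))

# Top-p: the ascending cumulative sum <= 1 - p is the tail to drop;
# the last (most likely) position is always kept.
probs_sort = logits_sort.softmax(dim=-1)
top_p_mask = probs_sort.cumsum(dim=-1) <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))

# Undo the sort (argsort of the index permutation inverts it).
print(logits_sort.gather(1, logits_idx.argsort(dim=-1)))
# tensor([[-inf, 4., -inf, 3.]])
```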
@@ -301,6 +301,36 @@ def _apply_alphabet_soup(
     return logits


+def _apply_min_p(
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
+def _apply_top_a(
+    logits: torch.Tensor,
+    top_a: torch.Tensor,
+) -> torch.Tensor:
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
+    tokens_to_remove = probs < threshold
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
 def _apply_tfs(
     logits: torch.Tensor,
     tfs: torch.Tensor,
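Min-p and top-a are now standalone passes; both derive a per-row threshold from the most likely token's probability (`min_p * max_prob` and `top_a * max_prob ** 2` respectively). A toy illustration with made-up values:

```python
import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
max_prob = probs.max(dim=-1, keepdim=True).values     # 0.50

min_p = torch.tensor([[0.4]])
print(probs < min_p * max_prob)           # threshold 0.20: drops 0.15 and 0.05

top_a = torch.tensor([[1.4]])
print(probs < top_a * max_prob.pow(2))    # threshold 0.35: drops 0.30, 0.15, 0.05
```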
@@ -393,37 +423,6 @@ def _apply_typical_sampling(
     return logits


-# pulls double duty for temperature and dynatemp
-def _apply_temperature(
-    logits: torch.Tensor,
-    temperatures: torch.Tensor,
-    # dynatemp_mins: torch.Tensor,
-    # dynatemp_maxs: torch.Tensor,
-    # dynatemp_exps: torch.Tensor,
-) -> torch.Tensor:
-    # dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
-    # dynatemp_mins = dynatemp_mins[dynatemp_mask]
-    # dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
-    # dynatemp_exps = dynatemp_exps[dynatemp_mask]
-    # dynatemp_mins = dynatemp_mins.clamp_(min=0)
-
-    # dynatemp_logits = logits[dynatemp_mask]
-    # dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
-    # dynatemp_probs = dynatemp_shifted_logits.exp()
-    # dynatemp_entropies = -(dynatemp_probs *
-    #                        dynatemp_shifted_logits).nansum(dim=-1)
-    # dynatemp_max_entropies = torch.log_(
-    #     (dynatemp_logits > float("-inf")).sum(dim=-1).float())
-    # normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
-    # dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
-    #             normalized_entropies.pow_(dynatemp_exps))
-
-    # temperatures[dynatemp_mask] = dyn_temp
-    # temperatures[temperatures == 0.0] = 1.0
-    logits.div_(temperatures.unsqueeze_(dim=1))
-    return logits
-
-
 def _apply_quadratic_sampling(
     logits: torch.Tensor,
     smoothing_factor: torch.Tensor,

+ 116 - 115
aphrodite/modeling/sampling_metadata.py

@@ -58,9 +58,11 @@ class SamplingMetadata:
     hidden_states = execute_model(...)
     logits = hidden_states[sampling_metadata.selected_token_indices]
     sample(logits)
+
     def sample(logits):
         # Use categorized_sample_indices for sampling....
     ```
+
     Args:
         seq_groups: List of batched sequence groups.
         selected_token_indices: (num_query_tokens_to_logprob). Indices to find
@@ -141,6 +143,7 @@ def _prepare_seq_groups(
 ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
         SamplingType, List[Tuple[int, int]]], int]:
     """Prepare sequence groups and indices for sampling.
+
     Args:
         seq_group_metadata_list: A list of sequence group to batch.
         prompt_lens: A list of prompt lens per sequence group.
@@ -149,6 +152,7 @@ def _prepare_seq_groups(
             of entire prompt tokens, and it could be shorter.
         device: A device to use for random number generator,
             `SequenceGroupToSample.generator`.
+
     Returns:
         seq_groups: A list of sequence group to sample.
         selected_token_indices: See the definition from `SamplingMetadata`.
@@ -215,6 +219,7 @@ def _prepare_seq_groups(
         """
         """
         This blocks computes selected_token_indices which is used in the
         This blocks computes selected_token_indices which is used in the
         following way.
         following way.
+
         hidden_states = model(...)
         hidden_states = model(...)
         logits = hidden_states[selected_token_indices]
         logits = hidden_states[selected_token_indices]
         """
         """
@@ -232,6 +237,7 @@ def _prepare_seq_groups(
         """
         """
         This block computes categorized_sample_indices which is used in the
         This block computes categorized_sample_indices which is used in the
         following way.
         following way.
+
         hidden_states = model(...)
         hidden_states = model(...)
         logits = hidden_states[selected_token_indices]
         logits = hidden_states[selected_token_indices]
         def sample(logits):
         def sample(logits):
@@ -274,6 +280,7 @@ def _prepare_seq_groups(
 @dataclass
 class SamplingTensors:
     """Tensors for sampling."""
+
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
@@ -286,9 +293,6 @@ class SamplingTensors:
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
     typical_ps: torch.Tensor
-    dynatemp_mins: torch.Tensor
-    dynatemp_maxs: torch.Tensor
-    dynatemp_exps: torch.Tensor
     smoothing_factors: torch.Tensor
     smoothing_curves: torch.Tensor
     sampling_seeds: torch.Tensor
@@ -308,7 +312,12 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool]:
+               bool, bool]:
+        """
+        extra_seeds_to_generate: extra seeds to generate using the
+            user-defined seed for each sequence.
+        extra_entropy: extra entropy to use when generating seeds.
+        """
         prompt_tokens: List[List[int]] = []
         output_tokens: List[List[int]] = []
         top_ks: List[int] = []
@@ -323,20 +332,15 @@ class SamplingTensors:
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
         typical_ps: List[float] = []
-        dynatemp_mins: List[float] = []
-        dynatemp_maxs: List[float] = []
-        dynatemp_exps: List[float] = []
         smoothing_factors: List[float] = []
         smoothing_curves: List[float] = []
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
         prompt_best_of: List[int] = []
-        do_temperatures = False
         do_penalties = False
-        do_topks = False
-        do_topps = False
-        do_topas = False
-        do_minps = False
+        do_top_p_top_k = False
+        do_top_as = False
+        do_min_p = False
         do_tfss = False
         do_eta_cutoffs = False
         do_epsilon_cutoffs = False
@@ -356,38 +360,37 @@ class SamplingTensors:
             f = sampling_params.frequency_penalty
             r = sampling_params.repetition_penalty
             top_p = sampling_params.top_p
-            # k should not be greater than the vocab size
-            top_k = min(sampling_params.top_k, vocab_size)
-            top_k = vocab_size if top_k == -1 else top_k
             top_a = sampling_params.top_a
             min_p = sampling_params.min_p
             tfs = sampling_params.tfs
             eta_cutoff = sampling_params.eta_cutoff
             epsilon_cutoff = sampling_params.epsilon_cutoff
             typical_p = sampling_params.typical_p
-            dynatemp_min = sampling_params.dynatemp_min
-            dynatemp_max = sampling_params.dynatemp_max
-            dynatemp_exp = sampling_params.dynatemp_exponent
             smoothing_factor = sampling_params.smoothing_factor
             smoothing_curve = sampling_params.smoothing_curve
             seed = sampling_params.seed

             is_greedy = sampling_params.sampling_type == SamplingType.GREEDY

-            if do_temperatures is False and temperature > _SAMPLING_EPS:
-                do_temperatures = True
+            # k should not be greater than the vocab size.
+            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = vocab_size if top_k == -1 else top_k
+            if temperature < _SAMPLING_EPS:
+                # NOTE: Zero temperature means deterministic sampling
+                # (i.e., greedy sampling or beam search).
+                # Set the temperature to 1 to avoid division by zero.
+                temperature = 1.0
+            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
+                                       or top_k != vocab_size):
+                do_top_p_top_k = True
+            if do_top_as is False and top_a > 0.0:
+                do_top_as = True
+            if not do_min_p and min_p > _SAMPLING_EPS:
+                do_min_p = True
             if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                      or abs(f) >= _SAMPLING_EPS
                                      or abs(r - 1.0) >= _SAMPLING_EPS):
                 do_penalties = True
-            if do_topks is False and top_k != vocab_size:
-                do_topks = True
-            if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
-                do_topps = True
-            if do_topas is False and top_a > 0.0:
-                do_topas = True
-            if do_minps is False and min_p > _SAMPLING_EPS:
-                do_minps = True
             if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                 do_tfss = True
             if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
@@ -403,8 +406,8 @@ class SamplingTensors:
             is_prompt = seq_group.is_prompt
             if (seq_group.is_prompt
                     and sampling_params.prompt_logprobs is not None):
-                # For tokens in the prompt that we only need to get their
-                # logprobs
+                # For tokens in the prompt that we only need to get
+                # their logprobs
                 subquery_len = seq_group.subquery_len
                 assert subquery_len is not None
                 prefill_len = len(seq_group.prompt_logprob_indices)
@@ -420,9 +423,6 @@ class SamplingTensors:
                 eta_cutoffs += [0] * prefill_len
                 epsilon_cutoffs += [0] * prefill_len
                 typical_ps += [1] * prefill_len
-                dynatemp_mins += [dynatemp_min] * prefill_len
-                dynatemp_maxs += [dynatemp_max] * prefill_len
-                dynatemp_exps += [dynatemp_exp] * prefill_len
                 smoothing_factors += [smoothing_factor] * prefill_len
                 smoothing_curves += [smoothing_curve] * prefill_len
                 prompt_tokens.extend([] for _ in range(prefill_len))
@@ -435,23 +435,20 @@ class SamplingTensors:
                     seq_data = seq_group.seq_data[seq_id]
                     prompt_tokens.append(seq_data.prompt_token_ids)
                     output_tokens.append(seq_data.output_token_ids)
-            temperatures += [temperature] * len(seq_ids)
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
-            top_as += [top_a] * len(seq_ids)
-            min_ps += [min_p] * len(seq_ids)
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
-            repetition_penalties += [r] * len(seq_ids)
-            tfss += [tfs] * len(seq_ids)
-            eta_cutoffs += [eta_cutoff] * len(seq_ids)
-            epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
-            typical_ps += [typical_p] * len(seq_ids)
-            dynatemp_mins += [dynatemp_min] * len(seq_ids)
-            dynatemp_maxs += [dynatemp_max] * len(seq_ids)
-            dynatemp_exps += [dynatemp_exp] * len(seq_ids)
-            smoothing_factors += [smoothing_factor] * len(seq_ids)
-            smoothing_curves += [smoothing_curve] * len(seq_ids)
+                temperatures += [temperature] * len(seq_ids)
+                top_ps += [top_p] * len(seq_ids)
+                top_ks += [top_k] * len(seq_ids)
+                top_as += [top_a] * len(seq_ids)
+                min_ps += [min_p] * len(seq_ids)
+                presence_penalties += [p] * len(seq_ids)
+                frequency_penalties += [f] * len(seq_ids)
+                repetition_penalties += [r] * len(seq_ids)
+                tfss += [tfs] * len(seq_ids)
+                eta_cutoffs += [eta_cutoff] * len(seq_ids)
+                epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
+                typical_ps += [typical_p] * len(seq_ids)
+                smoothing_factors += [smoothing_factor] * len(seq_ids)
+                smoothing_curves += [smoothing_curve] * len(seq_ids)

             if is_prompt:
                 prompt_best_of.append(sampling_params.best_of)
@@ -474,13 +471,12 @@ class SamplingTensors:
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
             frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
-            epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
-            dynatemp_exps, smoothing_factors, smoothing_curves, sampling_seeds,
-            sample_indices, prompt_tokens, output_tokens, vocab_size,
-            extra_seeds_to_generate, device, dtype)
-        return (sampling_tensors, do_temperatures, do_penalties, do_topks,
-                do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
-                do_epsilon_cutoffs, do_typical_ps, do_quadratic)
+            epsilon_cutoffs, typical_ps, smoothing_factors, smoothing_curves,
+            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
+            vocab_size, extra_seeds_to_generate, device, dtype)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
+                do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
+                do_typical_ps, do_quadratic)

     @classmethod
     def from_lists(cls, temperatures: List[float], top_ps: List[float],
@@ -489,9 +485,7 @@ class SamplingTensors:
                    frequency_penalties: List[float],
                    repetition_penalties: List[float], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], dynatemp_mins: List[float],
-                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
-                   smoothing_factors: List[float],
+                   typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], sampling_seeds: List[int],
                    sample_indices: List[int], prompt_tokens: List[List[int]],
                    output_tokens: List[List[int]], vocab_size: int,
@@ -513,38 +507,52 @@ class SamplingTensors:
             for tokens in output_tokens
         ]

-        temperatures_t = torch.tensor(temperatures,
-                                      device="cpu",
-                                      dtype=dtype,
-                                      pin_memory=pin_memory)
-        top_ps_t = torch.tensor(top_ps,
-                                device="cpu",
-                                dtype=dtype,
-                                pin_memory=pin_memory)
-        top_ks_t = torch.tensor(top_ks,
-                                device="cpu",
-                                dtype=torch.int,
-                                pin_memory=pin_memory)
+        temperatures_t = torch.tensor(
+            temperatures,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ps_t = torch.tensor(
+            top_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
         top_as_t = torch.tensor(top_as,
                                 device="cpu",
                                 dtype=dtype,
                                 pin_memory=pin_memory)
-        min_ps_t = torch.tensor(min_ps,
-                                device="cpu",
-                                dtype=dtype,
-                                pin_memory=pin_memory)
-        presence_penalties_t = torch.tensor(presence_penalties,
-                                            device="cpu",
-                                            dtype=dtype,
-                                            pin_memory=pin_memory)
-        frequency_penalties_t = torch.tensor(frequency_penalties,
-                                             device="cpu",
-                                             dtype=dtype,
-                                             pin_memory=pin_memory)
-        repetition_penalties_t = torch.tensor(repetition_penalties,
-                                              device="cpu",
-                                              dtype=dtype,
-                                              pin_memory=pin_memory)
+        min_ps_t = torch.tensor(
+            min_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        presence_penalties_t = torch.tensor(
+            presence_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        frequency_penalties_t = torch.tensor(
+            frequency_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        repetition_penalties_t = torch.tensor(
+            repetition_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ks_t = torch.tensor(
+            top_ks,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
         tfss_t = torch.tensor(tfss,
                               device="cpu",
                               dtype=dtype,
@@ -561,18 +569,6 @@ class SamplingTensors:
                                     device="cpu",
                                     device="cpu",
                                     dtype=dtype,
                                     dtype=dtype,
                                     pin_memory=pin_memory)
                                     pin_memory=pin_memory)
-        dynatemp_mins_t = torch.tensor(dynatemp_mins,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
-        dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
-        dynatemp_exps_t = torch.tensor(dynatemp_exps,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
         smoothing_factors_t = torch.tensor(smoothing_factors,
         smoothing_factors_t = torch.tensor(smoothing_factors,
                                            device="cpu",
                                            device="cpu",
                                            dtype=dtype,
                                            dtype=dtype,
@@ -581,18 +577,24 @@ class SamplingTensors:
                                           device="cpu",
                                           device="cpu",
                                           dtype=dtype,
                                           dtype=dtype,
                                           pin_memory=pin_memory)
                                           pin_memory=pin_memory)
-        sample_indices_t = torch.tensor(sample_indices,
-                                        device="cpu",
-                                        dtype=torch.int,
-                                        pin_memory=pin_memory)
-        prompt_tensor = torch.tensor(prompt_padded_tokens,
-                                     device=device,
-                                     dtype=torch.long,
-                                     pin_memory=pin_memory)
-        output_tensor = torch.tensor(output_padded_tokens,
-                                     device=device,
-                                     dtype=torch.long,
-                                     pin_memory=pin_memory)
+        sample_indices_t = torch.tensor(
+            sample_indices,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        prompt_tensor = torch.tensor(
+            prompt_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        output_tensor = torch.tensor(
+            output_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
         # need to transpose and make contiguous to
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
@@ -602,6 +604,7 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         ).T.contiguous()
+
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
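A minimal sketch of the pattern used throughout `from_lists` (hypothetical values): build the tensor in pinned CPU memory, then copy with `non_blocking=True`; the copy can overlap with other host work only because the source memory is page-locked:

```python
import torch

vals = [0.7, 0.9, 1.0]
pin = torch.cuda.is_available()  # pinning is only meaningful with a CUDA device
t_cpu = torch.tensor(vals, device="cpu", dtype=torch.float32, pin_memory=pin)
if pin:
    t_gpu = t_cpu.to(device="cuda", non_blocking=True)  # async copy from pinned memory
```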
@@ -613,6 +616,7 @@ class SamplingTensors:
         if not extra_seeds_gpu.numel():
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -629,9 +633,6 @@ class SamplingTensors:
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                  non_blocking=True),
-            dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
-            dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
-            dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
             smoothing_factors=smoothing_factors_t.to(device=device,
                                                      non_blocking=True),
             smoothing_curves=smoothing_curves_t.to(device=device,