
fix: sampler indexing issues in distributed environments (#546)

* attempt 1

* re-add samplers
AlpinDale committed 8 months ago
parent
commit 9ce319b03c
3 changed files with 202 additions and 195 deletions
  1. aphrodite/common/sampling_params.py (+9 -2)
  2. aphrodite/modeling/layers/sampler.py (+77 -78)
  3. aphrodite/modeling/sampling_metadata.py (+116 -115)

+ 9 - 2
aphrodite/common/sampling_params.py

@@ -5,7 +5,8 @@ from functools import cached_property
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from pydantic import conint
+from pydantic import Field
+from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
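
The import swap above tracks pydantic's recommendation to replace `conint` with the `Annotated` form; both enforce the same `ge=1` constraint at validation time. A minimal sketch (the `_Params` model is hypothetical, not from the repo):

    from typing import Optional

    from pydantic import BaseModel, Field, ValidationError
    from typing_extensions import Annotated

    class _Params(BaseModel):
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    _Params(truncate_prompt_tokens=4)      # accepted
    try:
        _Params(truncate_prompt_tokens=0)  # rejected: must be >= 1
    except ValidationError as exc:
        print(exc)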
 
@@ -170,7 +171,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessorFunc]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -220,6 +221,12 @@ class SamplingParams:
         self.logits_processors = logits_processors or []
         self.include_stop_str_in_output = include_stop_str_in_output
         self.truncate_prompt_tokens = truncate_prompt_tokens
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0
 
         self.default_values = {
             "n": 1,

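To illustrate the buffering rule added above: a stop string of length L can have at most L - 1 characters already emitted before its final character arrives, so holding back max(len(s)) - 1 characters guarantees a partial stop string is never streamed. A standalone sketch with toy values (not part of the commit):

    stop = ["###", "</s>"]
    include_stop_str_in_output = False

    if stop and not include_stop_str_in_output:
        output_text_buffer_length = max(len(s) for s in stop) - 1   # 3 here
    else:
        output_text_buffer_length = 0

    text = "some generated text </"       # "</s>" may still be completing
    visible = text[:-output_text_buffer_length or None]  # [:-0] would be empty
    print(visible)                        # trailing 3 chars held back
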
+ 77 - 78
aphrodite/modeling/layers/sampler.py

@@ -17,14 +17,18 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
+
     This layer does the following:
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply all the different sampler functions in the specified order.
-    4. Sample the next tokens.
+    3. Apply presence, frequency and repetition penalties.
+    4. Apply temperature scaling.
+    5. Apply top-p and top-k truncation.
+    6. Sample the next tokens.
     Here, each sequence group within the batch can have different sampling
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+
     The structure of the logits tensor is coupled with the seq_groups in
     sampling_metadata. Typically, each sequence in each seq_group has one row in
     logits for the next token to be sampled; however, for a seq_group with a
@@ -52,17 +56,16 @@ class Sampler(nn.Module):
         """
         assert logits is not None
         _, vocab_size = logits.shape
-        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
-        # have not been generated yet
+
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)
 
         # Prepare sampling tensors with pinned memory to avoid blocking.
-        (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
-         do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-         do_typical_ps,
-         do_quadratic) = (SamplingTensors.from_sampling_metadata(
-             sampling_metadata, vocab_size, logits.device, logits.dtype))
+        (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
+         do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
+         do_quadratic) = SamplingTensors.from_sampling_metadata(
+             sampling_metadata, vocab_size, logits.device, logits.dtype)
 
+        # Apply presence and frequency penalties.
         if do_penalties:
             logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                       sampling_tensors.output_tokens,
@@ -70,18 +73,30 @@ class Sampler(nn.Module):
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)
 
-        if (do_topks or do_topps or do_topas or do_minps):
-            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
-                                          sampling_tensors.top_ks,
-                                          sampling_tensors.top_as,
-                                          sampling_tensors.min_ps)
+        # Apply temperature scaling.
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
+
+        if do_top_p_top_k:
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)
+
+        if do_top_as:
+            logits = _apply_top_a(logits, sampling_tensors.top_as)
+
+        if do_min_p:
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
+
         if do_tfss:
             logits = _apply_tfs(logits, sampling_tensors.tfss)
+
         if do_eta_cutoffs:
             logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
+
         if do_epsilon_cutoffs:
             logits = _apply_epsilon_cutoff(logits,
                                            sampling_tensors.epsilon_cutoffs)
+
         if do_typical_ps:
             logits = _apply_typical_sampling(logits,
                                              sampling_tensors.typical_ps)
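
The unconditional in-place scaling above replaces the old do_temperatures branch; a toy illustration of the broadcasted `div_` (tensors invented for the example):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.0],
                           [4.0, 2.0, 0.0]])
    temperatures = torch.tensor([0.5, 2.0])

    # unsqueeze_ reshapes [batch] to [batch, 1] so div_ broadcasts row-wise,
    # sharpening row 0 (T < 1) and flattening row 1 (T > 1) without
    # allocating a new logits tensor.
    logits.div_(temperatures.unsqueeze_(dim=1))
    print(logits)
    # tensor([[4., 2., 0.],
    #         [2., 1., 0.]])
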
@@ -91,15 +106,7 @@ class Sampler(nn.Module):
                 logits, sampling_tensors.smoothing_factors,
                 sampling_tensors.smoothing_curves)
 
-        if do_temperatures:
-            logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                        # sampling_tensors.dynatemp_mins,
-                                        # sampling_tensors.dynatemp_maxs,
-                                        # sampling_tensors.dynatemp_exps
-                                        )
-
         banned_tokens = _get_custom_token_bans(sampling_metadata)
-        # assert len(banned_tokens) == logits.shape[0]
         logits = _apply_token_bans(logits, banned_tokens)
 
         # We use float32 for probabilities and log probabilities.
@@ -117,12 +124,14 @@ class Sampler(nn.Module):
             include_gpu_probs_tensor=self.include_gpu_probs_tensor,
             modify_greedy_probs=self._should_modify_greedy_probs_inplace,
         )
+
         if self.include_gpu_probs_tensor:
             assert maybe_sampled_tokens_tensor is not None
             sampled_tokens_tensor = maybe_sampled_tokens_tensor
             on_device_tensors = (probs, sampled_tokens_tensor)
         else:
             on_device_tensors = None
+
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
@@ -137,8 +146,10 @@ class Sampler(nn.Module):
         """Whether or not the sampler should modify the probability distribution
         of greedily-sampled tokens such that multinomial sampling would sample
         the greedily-sampled token.
+
         In other words, if True then we set the probability of the greedily-
         sampled token to 1.
+
         This is used by speculative decoding, which requires that the sampling
         method be encoded into the probability distribution.
         """
@@ -258,38 +269,27 @@ def _apply_min_tokens_penalty(
     return logits
 
 
-def _apply_alphabet_soup(
+def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
-    a: torch.Tensor,
-    m: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
-
-    # Apply top-p, min-p and top-a.
-    probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    min_p_thresholds = probs_sort[:, 0] * m
-    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
-    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
-    mask = (probs_sort < threshold.unsqueeze(1)
-            )  # Cull logits below the top-a threshold
-    mask.logical_or_(
-        probs_sum >
-        p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
-    mask[:, 0] = False  # Guarantee at least one token is pickable
-    logits_sort[mask] = -float("inf")
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
 
     # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
 
-    # Final mask.
-    mask = (mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
+    # Apply top-p.
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))
 
     # Re-sort the probabilities.
     src = torch.arange(logits_idx.shape[-1],
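
A toy walk-through (not from the repo) of the rewritten kernel: sort ascending, mask everything below the k-th largest logit, then mask the low-probability tail whose cumulative mass stays within 1 - p:

    import torch

    logits = torch.tensor([[0.0, 3.0, 1.0, 2.0]])
    k = torch.tensor([2])                  # keep the 2 largest logits
    p = torch.tensor([0.9])

    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # The k-th largest value sits at column (vocab_size - k) in ascending
    # order; everything smaller is masked out.
    kth = logits_sort.gather(
        1, (logits_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1))
    logits_sort.masked_fill_(logits_sort < kth, -float("inf"))

    # Mask the tail whose cumulative mass is <= 1 - p, always keeping the
    # most likely token (the last column after the ascending sort).
    probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))
    # The real function then scatters logits_sort back through logits_idx.
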
@@ -301,6 +301,36 @@ def _apply_alphabet_soup(
     return logits
 
 
+def _apply_min_p(
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
+def _apply_top_a(
+    logits: torch.Tensor,
+    top_a: torch.Tensor,
+) -> torch.Tensor:
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
+    tokens_to_remove = probs < threshold
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
 def _apply_tfs(
     logits: torch.Tensor,
     tfs: torch.Tensor,
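
Both new helpers derive their cutoff from the single most likely token, min-p linearly and top-a quadratically; a side-by-side sketch with invented probabilities:

    import torch

    probs = torch.tensor([[0.6, 0.25, 0.1, 0.05]])
    top_probs, _ = probs.max(dim=-1, keepdim=True)   # 0.6

    min_p_cutoff = 0.2 * top_probs            # 0.12: linear in the top prob
    top_a_cutoff = 0.5 * top_probs.pow(2)     # 0.18: quadratic in the top prob
    print(probs < min_p_cutoff)   # drops 0.10 and 0.05, keeps 0.25
    print(probs < top_a_cutoff)   # same here, but the cutoff grows faster
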
@@ -393,37 +423,6 @@ def _apply_typical_sampling(
     return logits
 
 
-# pulls double duty for temperature and dynatemp
-def _apply_temperature(
-    logits: torch.Tensor,
-    temperatures: torch.Tensor,
-    # dynatemp_mins: torch.Tensor,
-    # dynatemp_maxs: torch.Tensor,
-    # dynatemp_exps: torch.Tensor,
-) -> torch.Tensor:
-    # dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
-    # dynatemp_mins = dynatemp_mins[dynatemp_mask]
-    # dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
-    # dynatemp_exps = dynatemp_exps[dynatemp_mask]
-    # dynatemp_mins = dynatemp_mins.clamp_(min=0)
-
-    # dynatemp_logits = logits[dynatemp_mask]
-    # dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
-    # dynatemp_probs = dynatemp_shifted_logits.exp()
-    # dynatemp_entropies = -(dynatemp_probs *
-    #                        dynatemp_shifted_logits).nansum(dim=-1)
-    # dynatemp_max_entropies = torch.log_(
-    #     (dynatemp_logits > float("-inf")).sum(dim=-1).float())
-    # normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
-    # dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
-    #             normalized_entropies.pow_(dynatemp_exps))
-
-    # temperatures[dynatemp_mask] = dyn_temp
-    # temperatures[temperatures == 0.0] = 1.0
-    logits.div_(temperatures.unsqueeze_(dim=1))
-    return logits
-
-
 def _apply_quadratic_sampling(
     logits: torch.Tensor,
     smoothing_factor: torch.Tensor,

+ 116 - 115
aphrodite/modeling/sampling_metadata.py

@@ -58,9 +58,11 @@ class SamplingMetadata:
     hidden_states = execute_model(...)
     logits = hidden_states[sampling_metadata.selected_token_indices]
     sample(logits)
+
     def sample(logits):
         # Use categorized_sample_indices for sampling....
     ```
+
     Args:
         seq_groups: List of batched sequence groups.
         selected_token_indices: (num_query_tokens_to_logprob). Indices to find
@@ -141,6 +143,7 @@ def _prepare_seq_groups(
 ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
         SamplingType, List[Tuple[int, int]]], int]:
     """Prepare sequence groups and indices for sampling.
+
     Args:
         seq_group_metadata_list: A list of sequence group to batch.
         prompt_lens: A list of prompt lens per sequence group.
@@ -149,6 +152,7 @@ def _prepare_seq_groups(
             of entire prompt tokens, and it could be shorter.
         device: A device to use for random number generator,
             `SequenceGroupToSample.generator`.
+
     Returns:
         seq_groups: A list of sequence group to sample.
         selected_token_indices: See the definition from `SamplingMetadata`.
@@ -215,6 +219,7 @@ def _prepare_seq_groups(
         """
         This block computes selected_token_indices which is used in the
         following way.
+
         hidden_states = model(...)
         logits = hidden_states[selected_token_indices]
         """
@@ -232,6 +237,7 @@ def _prepare_seq_groups(
         """
         This block computes categorized_sample_indices which is used in the
         following way.
+
         hidden_states = model(...)
         logits = hidden_states[selected_token_indices]
         def sample(logits):
@@ -274,6 +280,7 @@ def _prepare_seq_groups(
 @dataclass
 class SamplingTensors:
     """Tensors for sampling."""
+
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
@@ -286,9 +293,6 @@ class SamplingTensors:
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
     typical_ps: torch.Tensor
-    dynatemp_mins: torch.Tensor
-    dynatemp_maxs: torch.Tensor
-    dynatemp_exps: torch.Tensor
     smoothing_factors: torch.Tensor
     smoothing_curves: torch.Tensor
     sampling_seeds: torch.Tensor
@@ -308,7 +312,12 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool]:
+               bool, bool]:
+        """
+        extra_seeds_to_generate: extra seeds to generate using the
+            user-defined seed for each sequence.
+        extra_entropy: extra entropy to use when generating seeds.
+        """
         prompt_tokens: List[List[int]] = []
         output_tokens: List[List[int]] = []
         top_ks: List[int] = []
@@ -323,20 +332,15 @@ class SamplingTensors:
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
         typical_ps: List[float] = []
-        dynatemp_mins: List[float] = []
-        dynatemp_maxs: List[float] = []
-        dynatemp_exps: List[float] = []
         smoothing_factors: List[float] = []
         smoothing_curves: List[float] = []
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
         prompt_best_of: List[int] = []
-        do_temperatures = False
         do_penalties = False
-        do_topks = False
-        do_topps = False
-        do_topas = False
-        do_minps = False
+        do_top_p_top_k = False
+        do_top_as = False
+        do_min_p = False
         do_tfss = False
         do_eta_cutoffs = False
         do_epsilon_cutoffs = False
@@ -356,38 +360,37 @@ class SamplingTensors:
             f = sampling_params.frequency_penalty
             r = sampling_params.repetition_penalty
             top_p = sampling_params.top_p
-            # k should not be greater than the vocab size
-            top_k = min(sampling_params.top_k, vocab_size)
-            top_k = vocab_size if top_k == -1 else top_k
             top_a = sampling_params.top_a
             min_p = sampling_params.min_p
             tfs = sampling_params.tfs
             eta_cutoff = sampling_params.eta_cutoff
             epsilon_cutoff = sampling_params.epsilon_cutoff
             typical_p = sampling_params.typical_p
-            dynatemp_min = sampling_params.dynatemp_min
-            dynatemp_max = sampling_params.dynatemp_max
-            dynatemp_exp = sampling_params.dynatemp_exponent
             smoothing_factor = sampling_params.smoothing_factor
             smoothing_curve = sampling_params.smoothing_curve
             seed = sampling_params.seed
 
             is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
 
-            if do_temperatures is False and temperature > _SAMPLING_EPS:
-                do_temperatures = True
+            # k should not be greater than the vocab size.
+            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = vocab_size if top_k == -1 else top_k
+            if temperature < _SAMPLING_EPS:
+                # NOTE: Zero temperature means deterministic sampling
+                # (i.e., greedy sampling or beam search).
+                # Set the temperature to 1 to avoid division by zero.
+                temperature = 1.0
+            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
+                                       or top_k != vocab_size):
+                do_top_p_top_k = True
+            if do_top_as is False and top_a > 0.0:
+                do_top_as = True
+            if not do_min_p and min_p > _SAMPLING_EPS:
+                do_min_p = True
             if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                      or abs(f) >= _SAMPLING_EPS
                                      or abs(r - 1.0) >= _SAMPLING_EPS):
                 do_penalties = True
-            if do_topks is False and top_k != vocab_size:
-                do_topks = True
-            if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
-                do_topps = True
-            if do_topas is False and top_a > 0.0:
-                do_topas = True
-            if do_minps is False and min_p > _SAMPLING_EPS:
-                do_minps = True
             if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                 do_tfss = True
             if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
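
The zero-temperature branch above exists because forward() now divides by temperature unconditionally; greedy requests get T = 1.0, which leaves their logits (and hence the argmax) unchanged. A condensed sketch of the flag computation with default parameters:

    _SAMPLING_EPS = 1e-5

    temperature, top_p, top_k, vocab_size = 0.0, 1.0, -1, 32000

    if temperature < _SAMPLING_EPS:
        temperature = 1.0            # dividing by 1.0 leaves argmax intact

    top_k = min(top_k, vocab_size)
    top_k = vocab_size if top_k == -1 else top_k
    do_top_p_top_k = top_p < 1.0 - _SAMPLING_EPS or top_k != vocab_size
    print(do_top_p_top_k)            # False: default params skip the kernel
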
@@ -403,8 +406,8 @@ class SamplingTensors:
             is_prompt = seq_group.is_prompt
             if (seq_group.is_prompt
                     and sampling_params.prompt_logprobs is not None):
-                # For tokens in the prompt that we only need to get their
-                # logprobs
+                # For tokens in the prompt that we only need to get
+                # their logprobs
                 subquery_len = seq_group.subquery_len
                 assert subquery_len is not None
                 prefill_len = len(seq_group.prompt_logprob_indices)
@@ -420,9 +423,6 @@ class SamplingTensors:
                 eta_cutoffs += [0] * prefill_len
                 epsilon_cutoffs += [0] * prefill_len
                 typical_ps += [1] * prefill_len
-                dynatemp_mins += [dynatemp_min] * prefill_len
-                dynatemp_maxs += [dynatemp_max] * prefill_len
-                dynatemp_exps += [dynatemp_exp] * prefill_len
                 smoothing_factors += [smoothing_factor] * prefill_len
                 smoothing_curves += [smoothing_curve] * prefill_len
                 prompt_tokens.extend([] for _ in range(prefill_len))
@@ -435,23 +435,20 @@ class SamplingTensors:
                     seq_data = seq_group.seq_data[seq_id]
                     prompt_tokens.append(seq_data.prompt_token_ids)
                     output_tokens.append(seq_data.output_token_ids)
-            temperatures += [temperature] * len(seq_ids)
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
-            top_as += [top_a] * len(seq_ids)
-            min_ps += [min_p] * len(seq_ids)
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
-            repetition_penalties += [r] * len(seq_ids)
-            tfss += [tfs] * len(seq_ids)
-            eta_cutoffs += [eta_cutoff] * len(seq_ids)
-            epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
-            typical_ps += [typical_p] * len(seq_ids)
-            dynatemp_mins += [dynatemp_min] * len(seq_ids)
-            dynatemp_maxs += [dynatemp_max] * len(seq_ids)
-            dynatemp_exps += [dynatemp_exp] * len(seq_ids)
-            smoothing_factors += [smoothing_factor] * len(seq_ids)
-            smoothing_curves += [smoothing_curve] * len(seq_ids)
+                temperatures += [temperature] * len(seq_ids)
+                top_ps += [top_p] * len(seq_ids)
+                top_ks += [top_k] * len(seq_ids)
+                top_as += [top_a] * len(seq_ids)
+                min_ps += [min_p] * len(seq_ids)
+                presence_penalties += [p] * len(seq_ids)
+                frequency_penalties += [f] * len(seq_ids)
+                repetition_penalties += [r] * len(seq_ids)
+                tfss += [tfs] * len(seq_ids)
+                eta_cutoffs += [eta_cutoff] * len(seq_ids)
+                epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
+                typical_ps += [typical_p] * len(seq_ids)
+                smoothing_factors += [smoothing_factor] * len(seq_ids)
+                smoothing_curves += [smoothing_curve] * len(seq_ids)
 
             if is_prompt:
                 prompt_best_of.append(sampling_params.best_of)
@@ -474,13 +471,12 @@ class SamplingTensors:
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
             frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
-            epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
-            dynatemp_exps, smoothing_factors, smoothing_curves, sampling_seeds,
-            sample_indices, prompt_tokens, output_tokens, vocab_size,
-            extra_seeds_to_generate, device, dtype)
-        return (sampling_tensors, do_temperatures, do_penalties, do_topks,
-                do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
-                do_epsilon_cutoffs, do_typical_ps, do_quadratic)
+            epsilon_cutoffs, typical_ps, smoothing_factors, smoothing_curves,
+            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
+            vocab_size, extra_seeds_to_generate, device, dtype)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
+                do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
+                do_typical_ps, do_quadratic)
 
     @classmethod
     def from_lists(cls, temperatures: List[float], top_ps: List[float],
@@ -489,9 +485,7 @@ class SamplingTensors:
                    frequency_penalties: List[float],
                    repetition_penalties: List[float], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], dynatemp_mins: List[float],
-                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
-                   smoothing_factors: List[float],
+                   typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], sampling_seeds: List[int],
                    sample_indices: List[int], prompt_tokens: List[List[int]],
                    output_tokens: List[List[int]], vocab_size: int,
@@ -513,38 +507,52 @@ class SamplingTensors:
             for tokens in output_tokens
         ]
 
-        temperatures_t = torch.tensor(temperatures,
-                                      device="cpu",
-                                      dtype=dtype,
-                                      pin_memory=pin_memory)
-        top_ps_t = torch.tensor(top_ps,
-                                device="cpu",
-                                dtype=dtype,
-                                pin_memory=pin_memory)
-        top_ks_t = torch.tensor(top_ks,
-                                device="cpu",
-                                dtype=torch.int,
-                                pin_memory=pin_memory)
+        temperatures_t = torch.tensor(
+            temperatures,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ps_t = torch.tensor(
+            top_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
         top_as_t = torch.tensor(top_as,
                                 device="cpu",
                                 dtype=dtype,
                                 pin_memory=pin_memory)
-        min_ps_t = torch.tensor(min_ps,
-                                device="cpu",
-                                dtype=dtype,
-                                pin_memory=pin_memory)
-        presence_penalties_t = torch.tensor(presence_penalties,
-                                            device="cpu",
-                                            dtype=dtype,
-                                            pin_memory=pin_memory)
-        frequency_penalties_t = torch.tensor(frequency_penalties,
-                                             device="cpu",
-                                             dtype=dtype,
-                                             pin_memory=pin_memory)
-        repetition_penalties_t = torch.tensor(repetition_penalties,
-                                              device="cpu",
-                                              dtype=dtype,
-                                              pin_memory=pin_memory)
+        min_ps_t = torch.tensor(
+            min_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        presence_penalties_t = torch.tensor(
+            presence_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        frequency_penalties_t = torch.tensor(
+            frequency_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        repetition_penalties_t = torch.tensor(
+            repetition_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ks_t = torch.tensor(
+            top_ks,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
         tfss_t = torch.tensor(tfss,
                               device="cpu",
                               dtype=dtype,
@@ -561,18 +569,6 @@ class SamplingTensors:
                                     device="cpu",
                                     dtype=dtype,
                                     pin_memory=pin_memory)
-        dynatemp_mins_t = torch.tensor(dynatemp_mins,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
-        dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
-        dynatemp_exps_t = torch.tensor(dynatemp_exps,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
         smoothing_factors_t = torch.tensor(smoothing_factors,
                                            device="cpu",
                                            dtype=dtype,
@@ -581,18 +577,24 @@ class SamplingTensors:
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
-        sample_indices_t = torch.tensor(sample_indices,
-                                        device="cpu",
-                                        dtype=torch.int,
-                                        pin_memory=pin_memory)
-        prompt_tensor = torch.tensor(prompt_padded_tokens,
-                                     device=device,
-                                     dtype=torch.long,
-                                     pin_memory=pin_memory)
-        output_tensor = torch.tensor(output_padded_tokens,
-                                     device=device,
-                                     dtype=torch.long,
-                                     pin_memory=pin_memory)
+        sample_indices_t = torch.tensor(
+            sample_indices,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        prompt_tensor = torch.tensor(
+            prompt_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        output_tensor = torch.tensor(
+            output_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
@@ -602,6 +604,7 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         ).T.contiguous()
+
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
 
@@ -613,6 +616,7 @@ class SamplingTensors:
         if not extra_seeds_gpu.numel():
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -629,9 +633,6 @@ class SamplingTensors:
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                  non_blocking=True),
-            dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
-            dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
-            dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
             smoothing_factors=smoothing_factors_t.to(device=device,
                                                      non_blocking=True),
             smoothing_curves=smoothing_curves_t.to(device=device,