Browse Source

Overhauled SamplingTensors construction.
Fixed multiple bugs in sampler flags.
Restored the functioning logitproc format.
Fixed sudden NaNs in quadratic smoothing.
Rewrote mirostat to work with seeds and other samplers.
Removed branches from some samplers.

50h100a committed 11 months ago
parent
commit
7ed57e318d
2 changed files with 316 additions and 498 deletions
  1. +153 -170
      aphrodite/modeling/layers/sampler.py
  2. +163 -328
      aphrodite/modeling/sampling_metadata.py

+ 153 - 170
aphrodite/modeling/layers/sampler.py

@@ -3,6 +3,7 @@ from typing import Dict, List, Tuple, Optional
 
 import torch
 import torch.nn as nn
+import math
 
 from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   OutputMetadata,
@@ -137,49 +138,50 @@ def _perform_sampling(
     logits = _apply_logits_processors(logits, sampling_metadata)
 
     # Prepare sampling tensors with pinned memory to avoid blocking.
-    (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
-     do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-     do_typical_ps, do_quadratic,
-     do_mirostat) = (SamplingTensors.from_sampling_metadata(
-         sampling_metadata, vocab_size, logits.device, logits.dtype))
+    sampling_tensors = SamplingTensors.from_sampling_metadata(
+        sampling_metadata, vocab_size, logits.device, logits.dtype)
 
-    if do_penalties:
+    if sampling_tensors.do_penalties:
         logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                   sampling_tensors.output_tokens,
-                                  sampling_tensors.presence_penalties,
-                                  sampling_tensors.frequency_penalties,
-                                  sampling_tensors.repetition_penalties)
+                                  sampling_tensors.pres_penalties,
+                                  sampling_tensors.freq_penalties,
+                                  sampling_tensors.rep_penalties)
 
-    if do_temperatures:
+    if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
         logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                     sampling_tensors.dynatemp_mins,
                                     sampling_tensors.dynatemp_maxs,
                                     sampling_tensors.dynatemp_exps)
 
-    if do_topks or do_topps or do_topas or do_minps:
+    if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
+            or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
         logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                       sampling_tensors.top_ks,
                                       sampling_tensors.top_as,
                                       sampling_tensors.min_ps)
-    if do_tfss:
+
+    if sampling_tensors.do_tfss:
         logits = _apply_tfs(logits, sampling_tensors.tfss)
-    if do_eta_cutoffs:
+    if sampling_tensors.do_eta_cutoffs:
         logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
-    if do_epsilon_cutoffs:
+    if sampling_tensors.do_epsilon_cutoffs:
         logits = _apply_epsilon_cutoff(logits,
                                        sampling_tensors.epsilon_cutoffs)
-    if do_typical_ps:
+    if sampling_tensors.do_typical_ps:
         logits = _apply_typical_sampling(logits, sampling_tensors.typical_ps)
-    if do_quadratic:
+
+    if sampling_tensors.do_quadratic:
         logits = _apply_quadratic_sampling(logits,
+                                           sampling_tensors.smoothing_indices,
                                            sampling_tensors.smoothing_factors,
                                            sampling_tensors.smoothing_curves)
 
     banned_tokens = _get_custom_token_bans(sampling_metadata)
     assert len(banned_tokens) == logits.shape[0]
     logits = _apply_token_bans(logits, banned_tokens)
-    if do_mirostat:
-        logits = _mirostat(logits, sampling_tensors, output_metadata)
+    if sampling_tensors.do_mirostat:
+        logits = _apply_mirostat_v2(logits, sampling_tensors)
 
     # We use float32 for probabilities and log probabilities.
     # Compute the probabilities.
@@ -190,6 +192,10 @@ def _perform_sampling(
 
     # Sample the next tokens.
     sample_results = _sample(probs, logprobs, sampling_metadata)
+
+    if sampling_tensors.do_mirostat:
+        _mirostat_store_args(logits, sampling_tensors, sample_results,
+                             sampling_metadata, output_metadata)
     # Get the logprobs query results.
     prompt_logprobs, sample_logprobs = _get_logprobs(logprobs,
                                                      sampling_metadata,
@@ -239,53 +245,32 @@ def _get_custom_token_bans(
     return banned_tokens
 
 
-# def _apply_logits_processors(
-#     logits: torch.Tensor,
-#     metadata: SamplingMetadata,
-# ) -> torch.Tensor:
-#     seq_offset = 0
-#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
-#         seq_size = len(seq_ids)
-#         output_tokens = []
-#         if (i < metadata.num_prompts
-#                 and sampling_params.prompt_logprobs is not None):
-#             prompt_seqs = metadata.prompt_lens[i] - 1
-#             seq_size += prompt_seqs
-#             output_tokens.extend([[]] * prompt_seqs)
-#         seq_end = seq_offset + seq_size
-
-#         if sampling_params.logits_processors:
-#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
-#                                  for sid in seq_ids)
-#             for proc in sampling_params.logits_processors:
-#                 proc(logits[seq_offset:seq_end], output_tokens)
-
-#         seq_offset = seq_end
-
-#     return logits
-
-
 def _apply_logits_processors(
     logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
+    metadata: SamplingMetadata,
 ) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
+    assert metadata.seq_groups is not None
+    assert metadata.prompt_lens is not None
+    assert metadata.seq_data is not None
+    seq_offset = 0
+    for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
+        seq_size = len(seq_ids)
+        output_tokens = []
+        if (i < metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_seqs = metadata.prompt_lens[i] - 1
+            seq_size += prompt_seqs
+            output_tokens.extend([[]] * prompt_seqs)
+        seq_end = seq_offset + seq_size
+
+        if sampling_params.logits_processors:
+            output_tokens.extend(metadata.seq_data[sid].output_token_ids
+                                 for sid in seq_ids)
+            for proc in sampling_params.logits_processors:
+                proc(logits[seq_offset:seq_end], output_tokens)
+
+        seq_offset = seq_end
+
     return logits
 
 
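The restored format calls each processor once per sequence group, passing a slice of the logits matrix plus the per-row output-token lists, rather than invoking it row by row. A minimal sketch of a processor compatible with that calling convention (the factory name is illustrative, not part of the codebase):

    import torch

    def make_token_ban_processor(banned_id: int):
        # Called as proc(logits[seq_offset:seq_end], output_tokens);
        # mutates the [rows, vocab_size] slice in place.
        def proc(logits: torch.Tensor, output_tokens: list) -> None:
            logits[:, banned_id] = -float("inf")
        return proc
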
@@ -398,20 +383,19 @@ def _apply_eta_cutoff(
     logits: torch.Tensor,
     eta_cutoff: torch.Tensor,
 ) -> torch.Tensor:
-    eta = torch.tensor(eta_cutoff, dtype=logits.dtype,
-                       device=logits.device) * 1e-4
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()
 
     neg_entropy = (probs * shifted_logits).nansum(dim=-1)
-    eps = torch.min(eta,
-                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)
+    eps = torch.min(eta_cutoff,
+                    torch.sqrt(eta_cutoff) *
+                    torch.exp(neg_entropy)).unsqueeze(dim=1)
 
     eta_mask = probs < eps
 
-    if torch.all(eta_mask):  # guard against nulling out all the logits
-        topk_prob, _ = torch.max(probs, dim=-1)
-        eta_mask = probs < topk_prob
+    # guard against nulling out all the logits
+    top_idx = torch.argmax(probs, dim=1, keepdim=True)
+    eta_mask.scatter_(dim=1, index=top_idx, value=False)
 
     logits[eta_mask] = -float("inf")
     return logits
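
Both cutoff filters now unmask each row's most likely token unconditionally, whereas the old torch.all guard only fired when every logit in the entire batch was masked. The pattern in isolation (shapes and values illustrative):

    import torch

    probs = torch.tensor([[0.96, 0.03, 0.01],
                          [0.40, 0.35, 0.25]])
    eps = torch.tensor([[0.99], [0.30]])  # row 0 would mask everything

    mask = probs < eps
    top_idx = torch.argmax(probs, dim=1, keepdim=True)
    mask.scatter_(dim=1, index=top_idx, value=False)
    assert not mask.all(dim=1).any()  # every row keeps at least one token
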
@@ -421,16 +405,13 @@ def _apply_epsilon_cutoff(
     logits: torch.Tensor,
     epsilon_cutoff: torch.Tensor,
 ) -> torch.Tensor:
-    eps = torch.tensor(epsilon_cutoff,
-                       dtype=logits.dtype,
-                       device=logits.device).unsqueeze(dim=1)
     probs = logits.softmax(dim=-1)
 
-    eps_mask = probs < (eps * 1e-4)
+    eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
 
-    if torch.all(eps_mask):  # guard against nulling out all the logits
-        topk_prob, _ = torch.max(probs, dim=-1)
-        eps_mask = probs < topk_prob
+    # guard against nulling out all the logits
+    top_idx = torch.argmax(probs, dim=1, keepdim=True)
+    eps_mask.scatter_(dim=1, index=top_idx, value=False)
 
     logits[eps_mask] = -float("inf")
     return logits
@@ -440,7 +421,6 @@ def _apply_typical_sampling(
     logits: torch.Tensor,
     typical_p: torch.Tensor,
 ) -> torch.Tensor:
-    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()
 
@@ -449,7 +429,8 @@ def _apply_typical_sampling(
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
     _, indices = torch.sort(surprisal_deviations)
     reordered_probs = probs.gather(-1, indices)
-    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
+    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
+        dim=1)
 
     min_tokens_to_keep = 1
     # Keep at least min_tokens_to_keep
@@ -493,8 +474,9 @@ def _apply_temperature(
 
 def _apply_quadratic_sampling(
     logits: torch.Tensor,
-    smoothing_factors: torch.Tensor,
-    smoothing_curves: torch.Tensor,
+    indices: torch.Tensor,
+    factors: torch.Tensor,
+    curves: torch.Tensor,
 ) -> torch.Tensor:
     """
     Applies a quadratic transformation to the logits based on the
@@ -508,9 +490,11 @@ def _apply_quadratic_sampling(
 
     params:
         logits (torch.Tensor): The logits to be transformed.
-        smoothing_factors (torch.Tensor): The factors to scale the quadratic
+        indices (torch.Tensor): Indices selecting the logits rows to
+            transform, matching the lengths of the other tensors.
+        factors (torch.Tensor): The factors to scale the quadratic
             term in the transformation.
-        smoothing_curves (torch.Tensor): The factors to scale the cubic term
+        curves (torch.Tensor): The factors to scale the cubic term
             in the transformation.
 
     returns:
@@ -518,20 +502,20 @@ def _apply_quadratic_sampling(
 
     Credits: @kalomaze
     """
-    max_logits = logits.max(dim=-1, keepdim=True).values
-    diff = logits - max_logits
-    smoothing_factors.unsqueeze_(dim=1)
-    smoothing_curves.unsqueeze_(dim=1)
-
-    k = (3 - smoothing_curves) / 2
-    s = (smoothing_curves - 1) / 2
-
-    mask = smoothing_factors > 0
-    mask = mask.flatten()
-    transformed_logits = torch.where(
-        logits != float('-inf'), -(k * smoothing_factors * diff**2) +
-        (s * smoothing_factors * diff**3) + max_logits, logits)
-    logits[mask, :] = transformed_logits[mask, :]
+    factors.unsqueeze_(dim=1)
+    curves.unsqueeze_(dim=1)
+    k = factors * (3 - curves) / 2
+    s = factors * (curves - 1) / 2
+
+    quadlogits = logits[indices]  # project to only relevant logits
+    max_logits = quadlogits.max(dim=-1, keepdim=True).values
+
+    # Construct the delta from each logit to its new value
+    diff = quadlogits - max_logits
+    diff -= diff**2 * (s * diff - k)
+    diff[diff != diff] = 0  # Eliminate NaNs from infs
+
+    logits[indices] -= diff
     return logits
 
 
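The rewritten transform is the same curve algebraically (new logit = max + s*d^3 - k*d^2, with the smoothing factor folded into k and s), but it now computes a per-logit delta and zeroes the NaNs that arise wherever a logit is already -inf, which is where the sudden NaNs came from. A standalone check of that property (values illustrative):

    import torch

    logits = torch.tensor([[2.0, 0.5, -float("inf")]])
    factors, curves = torch.tensor([0.2]), torch.tensor([1.0])

    k = (factors * (3 - curves) / 2).unsqueeze(1)
    s = (factors * (curves - 1) / 2).unsqueeze(1)

    diff = logits - logits.max(dim=-1, keepdim=True).values
    diff -= diff**2 * (s * diff - k)
    diff[diff != diff] = 0   # NaNs appear only at the -inf entries
    print(logits - diff)     # tensor([[2.0000, 1.5500, -inf]])
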
@@ -539,7 +523,6 @@ def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
-    samples = samples.tolist()
     sample_idx = 0
     results = []
     for seq_group in selected_seq_groups:
@@ -548,7 +531,7 @@ def _greedy_sample(
         assert num_parent_seqs == 1, (
             "Greedy sampling should have only one seq.")
         parent_ids = list(range(num_parent_seqs))
-        next_token_ids = [samples[sample_idx]]
+        next_token_ids = [samples[sample_idx].item()]
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
     return results
@@ -671,6 +654,10 @@ def _sample(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
+    """Returns list of (selected_tokens, parent_seq_ids) tuples
+    corresponding to sampling_metadata.seq_groups."""
+    assert sampling_metadata.seq_groups is not None
+    assert sampling_metadata.categorized_sample_indices is not None
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
@@ -860,92 +847,88 @@ def _build_sampler_output(
     sample_logprobs: List[SampleLogprobs],
     output_metadata: OutputMetadata,
 ) -> SamplerOutput:
+    assert sampling_metadata.seq_groups is not None
     sampler_output = []
     for (seq_group, sample_result, group_prompt_logprobs,
          group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                        sample_results, prompt_logprobs,
                                        sample_logprobs):
         seq_ids, _ = seq_group
-        next_token_ids, parent_ids = sample_result
-        seq_outputs = []
-        for parent_id, next_token_id, logprobs in zip(parent_ids,
-                                                      next_token_ids,
-                                                      group_sample_logprobs):
-            seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
-                               output_metadata.get(seq_ids[parent_id])))
+        seq_outputs = [
+            SequenceOutput(seq_ids[parent_id], token_id, logprobs,
+                           output_metadata.get(seq_ids[parent_id], idx))
+            for idx, (token_id, parent_id, logprobs) in enumerate(
+                zip(*sample_result, group_sample_logprobs))
+        ]
+
         sampler_output.append(
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
     return sampler_output
 
 
-def _miro_store_args(seqids: List[int], mus: List[float],
-                     output_metadata: OutputMetadata) -> None:
-    for sid, mu in zip(seqids,
-                       mus.tolist()):  # tolist might be premature optimization
-        output_metadata.add(sid, "miro_mu", mu)
+def _apply_mirostat_v2(logits: torch.Tensor,
+                       sampling_tensors: SamplingTensors) -> torch.Tensor:
+    # Reduce our view to just the affected logits
+    logit_view = logits[sampling_tensors.miro_indices]
 
+    # Calculate surprise value per token
+    #  Convert nats to bits for compatibility with ooba/kobold parameters.
+    logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
 
-def _apply_mirostat_v2(
-        logits: torch.Tensor,
-        taus: torch.Tensor,  # AKA the targeted surprise
-        etas: torch.Tensor,  # AKA the learning rate
-        mus: torch.
-    Tensor,  # AKA the accumulator that always tries to approach [tau]
-) -> torch.Tensor:
-
-    logit_surprise = torch.softmax(
-        logits, dim=-1).log2_().neg_()  # Calculate surprise value per token
-    # For compatibility with ooba/kobold, done in unit of bits(log base 2)
-    # not nats(ln).
-    # Ideally this would be a log_softmax, for numerical stability and
-    # elegance purposes.
-    # logit_surprise = torch.log_softmax(logits, dim=-1).neg_()
-
-    miro_mask = logit_surprise > mus.unsqueeze(
-        dim=-1)  # Mask out "too-surprising" tokens (above mu)
-    mininds = torch.argmin(logit_surprise, dim=-1)
-    miro_mask.scatter_(
-        1, mininds.unsqueeze(dim=-1), False
-    )  # Force at least one outcome to be possible, ideally the most likely one
-
-    logits[miro_mask] = -float("inf")
-
-    probs = torch.softmax(logits, dim=-1,
-                          dtype=logits.dtype)  # Get probs, post-mask
-
-    # NOTE: Mirostat updates its `mu` values based on the sample chosen.
-    # The silly approach here is to just sample it and make the logits one-hot.
-    # This breaks fine grained seeding, but we don't have that yet.
-    # TODO: FIX when it gets added
-    next_token_ids = _multinomial(probs, num_samples=1)
-
-    # Calculation new `mu` values
-    # NOTE: If we can know the logit values of the PREVIOUS iteration,
-    # it should be possible to update `mu` before applying mirostat each
-    # iteration, thus letting us keep _sample as the last thing that happens.
-    picked_surprises = torch.gather(logit_surprise,
-                                    dim=-1,
-                                    index=next_token_ids)
-    eps = picked_surprises.squeeze() - taus
-    mus.sub_(etas * eps)
-
-    logits.fill_(-float("inf"))
-    # This value doesn't actually matter, so long as it's not -inf.
-    # Vectors are now one-hot, after all.
-    logits.scatter_(1, next_token_ids, 1.0)
-    return logits
+    # Mask out "too-surprising" tokens (surprisal > mu)
+    mus = sampling_tensors.miro_mus
+    miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
 
+    # Unmask most-likely logit to guarantee a selection.
+    maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
+    miro_mask.scatter_(dim=1, index=maxinds, value=False)
 
-def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
-              output_metadata: OutputMetadata) -> torch.Tensor:
-    idx = sampling_tensors.miro_indices
-    seqids = sampling_tensors.miro_seqids
-    taus = sampling_tensors.miro_taus
-    etas = sampling_tensors.miro_etas
-    mus = sampling_tensors.miro_mus
+    # Apply logit mask (effectively a top-k filter).
+    logit_view[miro_mask] = -float("inf")
 
-    logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
-                                     mus)  # mus is an i/o param, :vomit:
-    _miro_store_args(seqids, mus, output_metadata)
+    # Project logit changes made to the view onto the original.
+    # I think this step might be redundant.
+    logits[sampling_tensors.miro_indices] = logit_view
     return logits
+
+
+def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
+                         sample_results: List[Tuple[List[int], List[int]]],
+                         sampling_metadata: SamplingMetadata,
+                         output_metadata: OutputMetadata) -> None:
+    """Based on whichever token was finally sampled, we calculate the
+    final surprisal values to update the mus.
+    
+    Because a single sequence can have multiple samples, we must fork
+    the mu accordingly."""
+    assert sampling_metadata.seq_groups is not None
+    seqid_to_tokens = {}
+    seqid_to_indices = {}
+    for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
+                                          sample_results):
+        for idx, (token, parent) in enumerate(zip(toks, parents)):
+            seqid_to_tokens.setdefault(sids[parent], []).append(token)
+            seqid_to_indices.setdefault(sids[parent], []).append(idx)
+
+    seqids = args.miro_seqids
+
+    picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
+                                 device=logits.device,
+                                 dtype=torch.long)
+
+    # Clumsily, we recalculate token surprisals.
+    logits_view = logits[args.miro_indices]
+    picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
+                                   dim=-1,
+                                   index=picked_tokens) / -math.log(2)
+
+    taus = args.miro_taus.unsqueeze(dim=-1)  # AKA target surprisals
+    etas = args.miro_etas.unsqueeze(dim=-1)  # AKA learning rates
+    mus = args.miro_mus.unsqueeze(dim=-1)  # AKA surprisal accumulators
+    nu_mus = mus - (picked_surprise - taus) * etas
+
+    # Record updated mu values for use in the next iteration
+    # Note how each mu is split into multiple based on the number of samples.
+    for seqid, seq_mus in zip(seqids, nu_mus):
+        for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
+            output_metadata.add(seqid, sample_idx, "miro_mu", mu)
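
The v2 update rule itself is unchanged: after sampling, each mu moves toward the target surprisal tau at rate eta, with surprisals kept in bits for ooba/kobold compatibility. One update step in isolation (values illustrative):

    import math
    import torch

    logits = torch.tensor([3.0, 1.0, 0.0])
    tau, eta, mu = 5.0, 0.1, 2 * 5.0  # mu conventionally starts at 2*tau

    surprise = torch.log_softmax(logits, dim=-1) / -math.log(2)  # nats -> bits
    picked = surprise[1].item()  # surprisal of the sampled token
    mu -= eta * (picked - tau)   # matches nu_mus = mus - (surprise - tau) * eta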

+ 163 - 328
aphrodite/modeling/sampling_metadata.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Tuple, Optional, TypeVar, Callable
 
 import torch
 
@@ -15,16 +15,26 @@ class PersistentMetadata:
     def __init__(self, metadata: Optional[Dict[int, dict]] = None):
         self._metadata: Dict[int, dict] = metadata or {}
 
-    def get(self, seq_id: int) -> dict:
-        return self._metadata.get(seq_id, {})
+    def get(self, seq_id: int, key, default=None):
+        return self._metadata.get(seq_id, {}).get(key, default)
 
 
-class OutputMetadata(PersistentMetadata):
+class OutputMetadata():
+    """Not symmetrical with PersistentMetadata because the process of
+    sampling can produce unique metadata per sample, per sequence.
+    
+    The appropriate conversion would be `output[seq][sample](dict)` to
+    `persist[new_seq_for_sample](dict)`"""
 
-    def add(self, seq_id: int, key, val) -> None:
-        if seq_id not in self._metadata:
-            self._metadata[seq_id] = {}
-        self._metadata[seq_id][key] = val
+    def __init__(self):
+        self._metadata: Dict[int, Dict[int, dict]] = {}
+
+    def add(self, seq_id: int, sample_id: int, key, val) -> None:
+        (self._metadata.setdefault(seq_id, {}).setdefault(sample_id,
+                                                          {})[key]) = val
+
+    def get(self, seq_id: int, sample_id: int) -> dict:
+        return self._metadata.get(seq_id, {}).get(sample_id, {})
 
 
 class SamplingMetadata:
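
A quick sketch of the new two-level addressing; per-sample entries recorded here are meant to become per-sequence PersistentMetadata once each sample forks into its own sequence (values illustrative):

    out = OutputMetadata()
    out.add(seq_id=7, sample_id=0, key="miro_mu", val=9.5)
    out.add(seq_id=7, sample_id=1, key="miro_mu", val=9.8)

    out.get(7, 0)  # {'miro_mu': 9.5}
    out.get(7, 2)  # {} -- missing entries default to an empty dict
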
@@ -89,9 +99,9 @@ class SamplingTensors:
     top_ks: torch.Tensor
     top_as: torch.Tensor
     min_ps: torch.Tensor
-    presence_penalties: torch.Tensor
-    frequency_penalties: torch.Tensor
-    repetition_penalties: torch.Tensor
+    pres_penalties: torch.Tensor
+    freq_penalties: torch.Tensor
+    rep_penalties: torch.Tensor
     tfss: torch.Tensor
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
@@ -100,333 +110,158 @@ class SamplingTensors:
     miro_etas: torch.Tensor
     miro_mus: torch.Tensor
     miro_indices: torch.Tensor
-    miro_seqids: List[int]  # state writeback done CPU side
+    miro_seqids: List[int]
     dynatemp_mins: torch.Tensor
     dynatemp_maxs: torch.Tensor
     dynatemp_exps: torch.Tensor
+    smoothing_indices: torch.Tensor
     smoothing_factors: torch.Tensor
     smoothing_curves: torch.Tensor
     prompt_tokens: torch.Tensor
     output_tokens: torch.Tensor
 
-    @classmethod
-    def from_sampling_metadata(
-        cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
-        device: torch.device, dtype: torch.dtype
-    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool]:
-        prompt_tokens: List[List[int]] = []
-        output_tokens: List[List[int]] = []
-        top_ks: List[int] = []
-        temperatures: List[float] = []
-        top_ps: List[float] = []
-        top_as: List[float] = []
-        min_ps: List[float] = []
-        presence_penalties: List[float] = []
-        frequency_penalties: List[float] = []
-        repetition_penalties: List[float] = []
-        tfss: List[float] = []
-        eta_cutoffs: List[float] = []
-        epsilon_cutoffs: List[float] = []
-        typical_ps: List[float] = []
-        miro_taus: List[float] = []
-        miro_etas: List[float] = []
-        miro_mus: List[float] = []
-        miro_indices: List[int] = []
-        miro_seqids: List[int] = []
-        dynatemp_mins: List[float] = []
-        dynatemp_maxs: List[float] = []
-        dynatemp_exps: List[float] = []
-        smoothing_factors: List[float] = []
-        smoothing_curves: List[float] = []
-        index = 0  # temporary, needed for building miro_indices
-        do_temperatures = False
-        do_penalties = False
-        do_topks = False
-        do_topps = False
-        do_topas = False
-        do_minps = False
-        do_tfss = False
-        do_eta_cutoffs = False
-        do_epsilon_cutoffs = False
-        do_typical_ps = False
-        do_quadratic = False
-        do_mirostat = False
-        for i, seq_group in enumerate(sampling_metadata.seq_groups):
-            seq_ids, sampling_params = seq_group
-            temperature = sampling_params.temperature
-            p = sampling_params.presence_penalty
-            f = sampling_params.frequency_penalty
-            r = sampling_params.repetition_penalty
-            top_p = sampling_params.top_p
-            # k should not be greater than the vocab size
-            top_k = min(sampling_params.top_k, vocab_size)
-            top_k = vocab_size if top_k == -1 else top_k
-            top_a = sampling_params.top_a
-            min_p = sampling_params.min_p
-            tfs = sampling_params.tfs
-            eta_cutoff = sampling_params.eta_cutoff
-            epsilon_cutoff = sampling_params.epsilon_cutoff
-            typical_p = sampling_params.typical_p
-            miro_tau = sampling_params.mirostat_tau
-            miro_eta = sampling_params.mirostat_eta
-            dynatemp_min = sampling_params.dynatemp_min
-            dynatemp_max = sampling_params.dynatemp_max
-            dynatemp_exp = sampling_params.dynatemp_exponent
-            smoothing_factor = sampling_params.smoothing_factor
-            smoothing_curve = sampling_params.smoothing_curve
-
-            if do_temperatures is False and temperature > _SAMPLING_EPS:
-                do_temperatures = True
-            if not do_penalties and (abs(p) >= _SAMPLING_EPS
-                                     or abs(f) >= _SAMPLING_EPS
-                                     or abs(r - 1.0) >= _SAMPLING_EPS):
-                do_penalties = True
-            if do_topks is False and top_k != vocab_size:
-                do_topks = True
-            if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
-                do_topps = True
-            if do_topas is False and top_a > 0.0:
-                do_topas = True
-            if do_minps is False and min_p > _SAMPLING_EPS:
-                do_minps = True
-            if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
-                do_tfss = True
-            if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
-                do_eta_cutoffs = True
-            if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
-                do_epsilon_cutoffs = True
-            if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
-                do_typical_ps = True
-            if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
-                                          or smoothing_curve > 1.0):
-                do_quadratic = True
-            if do_mirostat is False and sampling_params.mirostat_mode == 2:
-                do_mirostat = True
-
-            if (i < sampling_metadata.num_prompts
-                    and sampling_params.prompt_logprobs is not None):
-                # For tokens in the prompt that we only need to get their
-                # logprobs
-                prompt_len = sampling_metadata.prompt_lens[i]
-                index += sampling_metadata.prompt_lens[i] - 1
-                temperatures += [temperature] * (prompt_len - 1)
-                top_ps += [top_p] * (prompt_len - 1)
-                top_ks += [top_k] * (prompt_len - 1)
-                top_as += [top_a] * (prompt_len - 1)
-                min_ps += [min_p] * (prompt_len - 1)
-                presence_penalties += [0] * (prompt_len - 1)
-                frequency_penalties += [0] * (prompt_len - 1)
-                repetition_penalties += [1] * (prompt_len - 1)
-                tfss += [1] * (prompt_len - 1)
-                eta_cutoffs += [0] * (prompt_len - 1)
-                epsilon_cutoffs += [0] * (prompt_len - 1)
-                typical_ps += [1] * (prompt_len - 1)
-                dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
-                dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
-                dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
-                smoothing_factors += [smoothing_factor] * (prompt_len - 1)
-                smoothing_curves += [smoothing_curve] * (prompt_len - 1)
-                prompt_tokens.extend([] for _ in range(prompt_len - 1))
-                output_tokens.extend([] for _ in range(prompt_len - 1))
-            for seq_id in seq_ids:
-                seq_data = sampling_metadata.seq_data[seq_id]
-                prompt_tokens.append(seq_data.prompt_token_ids)
-                output_tokens.append(seq_data.output_token_ids)
-            temperatures += [temperature] * len(seq_ids)
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
-            top_as += [top_a] * len(seq_ids)
-            min_ps += [min_p] * len(seq_ids)
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
-            repetition_penalties += [r] * len(seq_ids)
-            tfss += [tfs] * len(seq_ids)
-            eta_cutoffs += [eta_cutoff] * len(seq_ids)
-            epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
-            typical_ps += [typical_p] * len(seq_ids)
-            dynatemp_mins += [dynatemp_min] * len(seq_ids)
-            dynatemp_maxs += [dynatemp_max] * len(seq_ids)
-            dynatemp_exps += [dynatemp_exp] * len(seq_ids)
-            smoothing_factors += [smoothing_factor] * len(seq_ids)
-            smoothing_curves += [smoothing_curve] * len(seq_ids)
-            if sampling_params.mirostat_mode == 2:
-                miro_indices += [(index + i) for i in range(len(seq_ids))]
-                miro_seqids += seq_ids
-                miro_taus += [miro_tau] * len(seq_ids)
-                miro_etas += [miro_eta] * len(seq_ids)
-                miro_mus += [
-                    sampling_metadata.persistent_metadata.get(sid).get(
-                        "miro_mu", sampling_params.mirostat_tau * 2)
-                    for sid in seq_ids
-                ]
-            index += len(seq_ids)
-
-        sampling_tensors = SamplingTensors.from_lists(
-            temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
-            frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
-            epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
-            dynatemp_exps, miro_taus, miro_etas, miro_mus, miro_indices,
-            miro_seqids, smoothing_factors, smoothing_curves, prompt_tokens,
-            output_tokens, vocab_size, device, dtype)
-        return (sampling_tensors, do_temperatures, do_penalties, do_topks,
-                do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
-                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
+    do_temperatures: bool
+    do_dynatemps: bool
+    do_penalties: bool
+    do_top_ks: bool
+    do_top_ps: bool
+    do_top_as: bool
+    do_min_ps: bool
+    do_tfss: bool
+    do_eta_cutoffs: bool
+    do_epsilon_cutoffs: bool
+    do_typical_ps: bool
+    do_quadratic: bool
+    do_mirostat: bool
 
     @classmethod
-    def from_lists(cls, temperatures: List[float], top_ps: List[float],
-                   top_ks: List[int], top_as: List[float], min_ps: List[float],
-                   presence_penalties: List[float],
-                   frequency_penalties: List[float],
-                   repetition_penalties: List[float], tfss: List[float],
-                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], dynatemp_mins: List[float],
-                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
-                   miro_taus: List[float], miro_etas: List[float],
-                   miro_mus: List[float], miro_indices: List[int],
-                   miro_seqids: List[int], smoothing_factors: List[float],
-                   smoothing_curves: List[float],
-                   prompt_tokens: List[List[int]],
-                   output_tokens: List[List[int]], vocab_size: int,
-                   device: torch.device,
-                   dtype: torch.dtype) -> "SamplingTensors":
-        # Note that the performance will be very bad without
-        # pinned memory.
+    def from_sampling_metadata(cls, sampling_metadata: "SamplingMetadata",
+                               vocab_size: int, tgt_device: torch.device,
+                               float_dtype: torch.dtype) -> "SamplingTensors":
+        prompt_lens = sampling_metadata.prompt_lens or []
+        groups = sampling_metadata.seq_groups or []
+        seq_data = sampling_metadata.seq_data or {}
+        persistent = sampling_metadata.persistent_metadata
+
+        # Flattened list of (params, sid) matching the logits tensor.
+        # `sid < 0` implies a prompt seq.
+        unrolled_seqs: List[Tuple[SamplingParams, int]] = []
+        group_plens = prompt_lens + [0] * (len(groups) - len(prompt_lens))
+        for (ids, params), prompt_len in zip(groups, group_plens):
+            if prompt_len and params.prompt_logprobs is not None:
+                unrolled_seqs.extend([(params, -1)] * (prompt_len - 1))
+            unrolled_seqs.extend([(params, sid) for sid in ids])
+
+        T = TypeVar('T')
+
+        def _unroll(fn_val: Callable[[SamplingParams], T],
+                    prompt: Optional[T] = None) -> List[T]:
+            """`fn_val` for every seq, with an override for prompt seqs."""
+            return [
+                prompt if sid < 0 and prompt is not None else fn_val(p)
+                for p, sid in unrolled_seqs
+            ]
+
+        def _index(fn_mask: Callable[[SamplingParams], bool],
+                   prompt: Optional[bool] = None) -> List[int]:
+            """Index for every seq where `fn_mask` is true, with an override
+            for prompt seqs."""
+            return [
+                i for i, (p, sid) in enumerate(unrolled_seqs)
+                if (fn_mask(p) if prompt is None else (
+                    prompt if sid < 0 else fn_mask(p)))
+            ]
+
+        def _filter(arr: List[T], indices: List[int]) -> List[T]:
+            """Return only the elements of `arr` accessed by `indices`."""
+            return [arr[i] for i in indices]
+
+        miro_inds = _index(lambda p: p.mirostat_mode == 2, prompt=False)
+        _miro_seqs = _filter(unrolled_seqs, miro_inds)
+
+        quad_inds = _index(lambda p: p.smoothing_factor != 0)
+        _quad_seqs = _filter(unrolled_seqs, quad_inds)
+
+        fvars = {  # noqa
+            "temperatures": _unroll(lambda p: p.temperature),
+            "top_ps": _unroll(lambda p: p.top_p),
+            "top_as": _unroll(lambda p: p.top_a),
+            "min_ps": _unroll(lambda p: p.min_p),
+            "tfss": _unroll(lambda p: p.tfs, prompt=1),
+            "eta_cutoffs": _unroll(lambda p: p.eta_cutoff * 1e-4, prompt=0),
+            "epsilon_cutoffs": _unroll(lambda p: p.epsilon_cutoff * 1e-4, 0),
+            "typical_ps": _unroll(lambda p: p.typical_p, prompt=1),
+            "pres_penalties": _unroll(lambda p: p.presence_penalty, prompt=0),
+            "freq_penalties": _unroll(lambda p: p.frequency_penalty, prompt=0),
+            "rep_penalties": _unroll(lambda p: p.repetition_penalty, prompt=1),
+
+            "dynatemp_mins": _unroll(lambda p: p.dynatemp_min),
+            "dynatemp_maxs": _unroll(lambda p: p.dynatemp_max),
+            "dynatemp_exps": _unroll(lambda p: p.dynatemp_exponent),
+
+            "miro_taus": [p.mirostat_tau for p, _ in _miro_seqs],
+            "miro_etas": [p.mirostat_eta for p, _ in _miro_seqs],
+            "miro_mus": [persistent.get(sid, "miro_mu", p.mirostat_tau * 2)
+                         for p, sid in _miro_seqs],
+
+            "smoothing_factors": [p.smoothing_factor for p, _ in _quad_seqs],
+            "smoothing_curves": [p.smoothing_curve for p, _ in _quad_seqs],
+        }
+        ivars = {  # noqa
+            "top_ks": _unroll(lambda p: vocab_size
+                              if p.top_k == -1 else min(p.top_k, vocab_size)),
+            "miro_indices": miro_inds,
+            "smoothing_indices": quad_inds,
+        }
+
+        prompt_tokens = [[] if sid < 0 else seq_data[sid].prompt_token_ids
+                         for _, sid in unrolled_seqs]
+        output_tokens = [[] if sid < 0 else seq_data[sid].output_token_ids
+                         for _, sid in unrolled_seqs]
+
+        def _unjagged(arrs: List[List[T]], padval: T) -> List[List[T]]:
+            max_len = max(len(arr) for arr in arrs)
+            return [arr + [padval] * (max_len - len(arr)) for arr in arrs]
+
+        # Note that the performance will be very bad without pinned memory.
+        # Pinned memory allows non-blocking transfers to device.
         pin_memory = not in_wsl()
-        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
-        prompt_padded_tokens = [
-            tokens + [vocab_size] * (prompt_max_len - len(tokens))
-            for tokens in prompt_tokens
-        ]
-        output_max_len = max(len(tokens) for tokens in output_tokens)
-        output_padded_tokens = [
-            tokens + [vocab_size] * (output_max_len - len(tokens))
-            for tokens in output_tokens
-        ]
-
-        temperatures_t = torch.tensor(temperatures,
-                                      device="cpu",
-                                      dtype=dtype,
-                                      pin_memory=pin_memory)
-        top_ps_t = torch.tensor(top_ps,
-                                device="cpu",
-                                dtype=dtype,
-                                pin_memory=pin_memory)
-        top_ks_t = torch.tensor(top_ks,
-                                device="cpu",
-                                dtype=torch.int,
-                                pin_memory=pin_memory)
-        top_as_t = torch.tensor(top_as,
-                                device="cpu",
-                                dtype=dtype,
-                                pin_memory=pin_memory)
-        min_ps_t = torch.tensor(min_ps,
-                                device="cpu",
-                                dtype=dtype,
-                                pin_memory=pin_memory)
-        presence_penalties_t = torch.tensor(presence_penalties,
-                                            device="cpu",
-                                            dtype=dtype,
-                                            pin_memory=pin_memory)
-        frequency_penalties_t = torch.tensor(frequency_penalties,
-                                             device="cpu",
-                                             dtype=dtype,
-                                             pin_memory=pin_memory)
-        repetition_penalties_t = torch.tensor(repetition_penalties,
-                                              device="cpu",
-                                              dtype=dtype,
-                                              pin_memory=pin_memory)
-        tfss_t = torch.tensor(tfss,
-                              device="cpu",
-                              dtype=dtype,
-                              pin_memory=pin_memory)
-        eta_cutoffs_t = torch.tensor(eta_cutoffs,
-                                     device="cpu",
-                                     dtype=dtype,
-                                     pin_memory=pin_memory)
-        epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
-                                         device="cpu",
-                                         dtype=dtype,
-                                         pin_memory=pin_memory)
-        typical_ps_t = torch.tensor(typical_ps,
-                                    device="cpu",
-                                    dtype=dtype,
-                                    pin_memory=pin_memory)
-        dynatemp_mins_t = torch.tensor(dynatemp_mins,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
-        dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
-        dynatemp_exps_t = torch.tensor(dynatemp_exps,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
-        smoothing_factors_t = torch.tensor(smoothing_factors,
-                                           device="cpu",
-                                           dtype=dtype,
-                                           pin_memory=pin_memory)
-        smoothing_curves_t = torch.tensor(smoothing_curves,
-                                          device="cpu",
-                                          dtype=dtype,
-                                          pin_memory=pin_memory)
-        miro_taus_t = torch.tensor(miro_taus,
-                                   device="cpu",
-                                   dtype=dtype,
-                                   pin_memory=pin_memory)
-        miro_etas_t = torch.tensor(miro_etas,
-                                   device="cpu",
-                                   dtype=dtype,
-                                   pin_memory=pin_memory)
-        miro_mus_t = torch.tensor(miro_mus,
-                                  device="cpu",
-                                  dtype=dtype,
-                                  pin_memory=pin_memory)
-        miro_indices_t = torch.tensor(miro_indices,
-                                      device="cpu",
-                                      dtype=torch.int,
-                                      pin_memory=pin_memory)
-        prompt_tensor = torch.tensor(prompt_padded_tokens,
-                                     device=device,
-                                     dtype=torch.long,
-                                     pin_memory=pin_memory)
-        output_tensor = torch.tensor(output_padded_tokens,
-                                     device=device,
-                                     dtype=torch.long,
-                                     pin_memory=pin_memory)
-        # Because the memory is pinned, we can do non-blocking
-        # transfer to device.
+
+        def _tensor(contents: list, dtype) -> torch.Tensor:
+            loc_t = torch.tensor(contents,
+                                 dtype=dtype,
+                                 device="cpu",
+                                 pin_memory=pin_memory)
+            return loc_t.to(device=tgt_device, non_blocking=True)
+
         return cls(
-            temperatures=temperatures_t.to(device=device, non_blocking=True),
-            top_ps=top_ps_t.to(device=device, non_blocking=True),
-            top_ks=top_ks_t.to(device=device, non_blocking=True),
-            top_as=top_as_t.to(device=device, non_blocking=True),
-            min_ps=min_ps_t.to(device=device, non_blocking=True),
-            presence_penalties=presence_penalties_t.to(device=device,
-                                                       non_blocking=True),
-            frequency_penalties=frequency_penalties_t.to(device=device,
-                                                         non_blocking=True),
-            repetition_penalties=repetition_penalties_t.to(device=device,
-                                                           non_blocking=True),
-            tfss=tfss_t.to(device=device, non_blocking=True),
-            eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
-            epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
-                                                 non_blocking=True),
-            dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
-            dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
-            dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
-            smoothing_factors=smoothing_factors_t.to(device=device,
-                                                     non_blocking=True),
-            smoothing_curves=smoothing_curves_t.to(device=device,
-                                                   non_blocking=True),
-            miro_taus=miro_taus_t.to(device=device, non_blocking=True),
-            miro_etas=miro_etas_t.to(device=device, non_blocking=True),
-            miro_mus=miro_mus_t.to(device=device, non_blocking=True),
-            miro_indices=miro_indices_t.to(device=device, non_blocking=True),
-            miro_seqids=miro_seqids,
-            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
-            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
-            output_tokens=output_tensor.to(device=device, non_blocking=True),
+            #  Flags and non-tensor fields
+            do_temperatures=any(x != 1 for x in fvars["temperatures"]),
+            do_dynatemps=(any(fvars["dynatemp_mins"])
+                          or any(fvars["dynatemp_maxs"])),
+            do_top_ks=any(x != vocab_size for x in ivars["top_ks"]),
+            do_top_ps=any(x != 1 for x in fvars["top_ps"]),
+            do_top_as=any(fvars["top_as"]),
+            do_min_ps=any(fvars["min_ps"]),
+            do_tfss=any(x != 1 for x in fvars["tfss"]),
+            do_eta_cutoffs=any(fvars["eta_cutoffs"]),
+            do_epsilon_cutoffs=any(fvars["epsilon_cutoffs"]),
+            do_typical_ps=any(x != 1 for x in fvars["typical_ps"]),
+            do_penalties=(any(fvars["pres_penalties"])
+                          or any(fvars["freq_penalties"])
+                          or any(x != 1 for x in fvars["rep_penalties"])),
+            do_quadratic=len(quad_inds) > 0,
+            do_mirostat=len(miro_inds) > 0,
+            miro_seqids=_filter([s for _, s in unrolled_seqs], miro_inds),
+            # Float tensors
+            **{n: _tensor(vals, float_dtype)
+               for n, vals in fvars.items()},
+            # Integer tensors
+            **{n: _tensor(vals, torch.int)
+               for n, vals in ivars.items()},
+            # Token ID tensors
+            prompt_tokens=_tensor(_unjagged(prompt_tokens, vocab_size),
+                                  torch.long),
+            output_tokens=_tensor(_unjagged(output_tokens, vocab_size),
+                                  torch.long),
         )
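
For reference, every index tensor above is relative to the unrolled layout: one entry per logits row, with prompt-logprob rows flagged by sid = -1 and given neutral parameters so they pass through the samplers untouched. A toy reconstruction of the flattening (identifiers illustrative, simplified to assume prompt_logprobs is requested for every prompt):

    groups = [([11], "params_a"), ([12, 13], "params_b")]
    prompt_lens = [3]  # group 0 is a prompt of length 3

    unrolled = []
    for (ids, params), plen in zip(groups, prompt_lens + [0]):
        if plen:  # prompt rows: one per prompt token after the first
            unrolled.extend([(params, -1)] * (plen - 1))
        unrolled.extend([(params, sid) for sid in ids])
    # [('params_a', -1), ('params_a', -1), ('params_a', 11),
    #  ('params_b', 12), ('params_b', 13)]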