|
@@ -1,5 +1,5 @@
|
|
|
from dataclasses import dataclass
|
|
|
-from typing import Dict, List, Tuple, Optional
|
|
|
+from typing import Dict, List, Tuple, Optional, TypeVar, Callable
|
|
|
|
|
|
import torch
|
|
|
|
|
@@ -15,16 +15,26 @@ class PersistentMetadata:
|
|
|
def __init__(self, metadata: Optional[Dict[int, dict]] = None):
|
|
|
self._metadata: Dict[int, dict] = metadata or {}
|
|
|
|
|
|
- def get(self, seq_id: int) -> dict:
|
|
|
- return self._metadata.get(seq_id, {})
|
|
|
+ def get(self, seq_id: int, key, default=None):
|
|
|
+ return self._metadata.get(seq_id, {}).get(key, default)
|
|
|
|
|
|
|
|
|
-class OutputMetadata(PersistentMetadata):
|
|
|
+class OutputMetadata():
|
|
|
+ """Not symmetrical with PersistentMetadata because the process of
|
|
|
+ sampling can produce unique metadata per sample, per sequence.
|
|
|
+
|
|
|
+ The appropriate conversion would be `output[seq][sample](dict)` to
|
|
|
+ `persist[new_seq_for_sample](dict)`"""
|
|
|
|
|
|
- def add(self, seq_id: int, key, val) -> None:
|
|
|
- if seq_id not in self._metadata:
|
|
|
- self._metadata[seq_id] = {}
|
|
|
- self._metadata[seq_id][key] = val
|
|
|
+ def __init__(self):
|
|
|
+ self._metadata: Dict[int, Dict[int, dict]] = {}
|
|
|
+
|
|
|
+ def add(self, seq_id: int, sample_id: int, key, val) -> None:
|
|
|
+ (self._metadata.setdefault(seq_id, {}).setdefault(sample_id,
|
|
|
+ {})[key]) = val
|
|
|
+
|
|
|
+ def get(self, seq_id: int, sample_id: int) -> dict:
|
|
|
+ return self._metadata.get(seq_id, {}).get(sample_id, {})
|
|
|
|
|
|
|
|
|
class SamplingMetadata:
|
|
@@ -89,9 +99,9 @@ class SamplingTensors:
|
|
|
top_ks: torch.Tensor
|
|
|
top_as: torch.Tensor
|
|
|
min_ps: torch.Tensor
|
|
|
- presence_penalties: torch.Tensor
|
|
|
- frequency_penalties: torch.Tensor
|
|
|
- repetition_penalties: torch.Tensor
|
|
|
+ pres_penalties: torch.Tensor
|
|
|
+ freq_penalties: torch.Tensor
|
|
|
+ rep_penalties: torch.Tensor
|
|
|
tfss: torch.Tensor
|
|
|
eta_cutoffs: torch.Tensor
|
|
|
epsilon_cutoffs: torch.Tensor
|
|
@@ -100,333 +110,158 @@ class SamplingTensors:
|
|
|
miro_etas: torch.Tensor
|
|
|
miro_mus: torch.Tensor
|
|
|
miro_indices: torch.Tensor
|
|
|
- miro_seqids: List[int] # state writeback done CPU side
|
|
|
+ miro_seqids: List[int]
|
|
|
dynatemp_mins: torch.Tensor
|
|
|
dynatemp_maxs: torch.Tensor
|
|
|
dynatemp_exps: torch.Tensor
|
|
|
+ smoothing_indices: torch.Tensor
|
|
|
smoothing_factors: torch.Tensor
|
|
|
smoothing_curves: torch.Tensor
|
|
|
prompt_tokens: torch.Tensor
|
|
|
output_tokens: torch.Tensor
|
|
|
|
|
|
- @classmethod
|
|
|
- def from_sampling_metadata(
|
|
|
- cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
|
|
|
- device: torch.device, dtype: torch.dtype
|
|
|
- ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
|
|
|
- bool, bool, bool, bool, bool]:
|
|
|
- prompt_tokens: List[List[int]] = []
|
|
|
- output_tokens: List[List[int]] = []
|
|
|
- top_ks: List[int] = []
|
|
|
- temperatures: List[float] = []
|
|
|
- top_ps: List[float] = []
|
|
|
- top_as: List[float] = []
|
|
|
- min_ps: List[float] = []
|
|
|
- presence_penalties: List[float] = []
|
|
|
- frequency_penalties: List[float] = []
|
|
|
- repetition_penalties: List[float] = []
|
|
|
- tfss: List[float] = []
|
|
|
- eta_cutoffs: List[float] = []
|
|
|
- epsilon_cutoffs: List[float] = []
|
|
|
- typical_ps: List[float] = []
|
|
|
- miro_taus: List[float] = []
|
|
|
- miro_etas: List[float] = []
|
|
|
- miro_mus: List[float] = []
|
|
|
- miro_indices: List[int] = []
|
|
|
- miro_seqids: List[int] = []
|
|
|
- dynatemp_mins: List[float] = []
|
|
|
- dynatemp_maxs: List[float] = []
|
|
|
- dynatemp_exps: List[float] = []
|
|
|
- smoothing_factors: List[float] = []
|
|
|
- smoothing_curves: List[float] = []
|
|
|
- index = 0 # temporary, needed for building miro_indices
|
|
|
- do_temperatures = False
|
|
|
- do_penalties = False
|
|
|
- do_topks = False
|
|
|
- do_topps = False
|
|
|
- do_topas = False
|
|
|
- do_minps = False
|
|
|
- do_tfss = False
|
|
|
- do_eta_cutoffs = False
|
|
|
- do_epsilon_cutoffs = False
|
|
|
- do_typical_ps = False
|
|
|
- do_quadratic = False
|
|
|
- do_mirostat = False
|
|
|
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
|
|
- seq_ids, sampling_params = seq_group
|
|
|
- temperature = sampling_params.temperature
|
|
|
- p = sampling_params.presence_penalty
|
|
|
- f = sampling_params.frequency_penalty
|
|
|
- r = sampling_params.repetition_penalty
|
|
|
- top_p = sampling_params.top_p
|
|
|
- # k should not be greater than the vocab size
|
|
|
- top_k = min(sampling_params.top_k, vocab_size)
|
|
|
- top_k = vocab_size if top_k == -1 else top_k
|
|
|
- top_a = sampling_params.top_a
|
|
|
- min_p = sampling_params.min_p
|
|
|
- tfs = sampling_params.tfs
|
|
|
- eta_cutoff = sampling_params.eta_cutoff
|
|
|
- epsilon_cutoff = sampling_params.epsilon_cutoff
|
|
|
- typical_p = sampling_params.typical_p
|
|
|
- miro_tau = sampling_params.mirostat_tau
|
|
|
- miro_eta = sampling_params.mirostat_eta
|
|
|
- dynatemp_min = sampling_params.dynatemp_min
|
|
|
- dynatemp_max = sampling_params.dynatemp_max
|
|
|
- dynatemp_exp = sampling_params.dynatemp_exponent
|
|
|
- smoothing_factor = sampling_params.smoothing_factor
|
|
|
- smoothing_curve = sampling_params.smoothing_curve
|
|
|
-
|
|
|
- if do_temperatures is False and temperature > _SAMPLING_EPS:
|
|
|
- do_temperatures = True
|
|
|
- if not do_penalties and (abs(p) >= _SAMPLING_EPS
|
|
|
- or abs(f) >= _SAMPLING_EPS
|
|
|
- or abs(r - 1.0) >= _SAMPLING_EPS):
|
|
|
- do_penalties = True
|
|
|
- if do_topks is False and top_k != vocab_size:
|
|
|
- do_topks = True
|
|
|
- if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
|
|
|
- do_topps = True
|
|
|
- if do_topas is False and top_a > 0.0:
|
|
|
- do_topas = True
|
|
|
- if do_minps is False and min_p > _SAMPLING_EPS:
|
|
|
- do_minps = True
|
|
|
- if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
|
|
|
- do_tfss = True
|
|
|
- if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
|
|
|
- do_eta_cutoffs = True
|
|
|
- if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
|
|
|
- do_epsilon_cutoffs = True
|
|
|
- if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
|
|
|
- do_typical_ps = True
|
|
|
- if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
|
|
|
- or smoothing_curve > 1.0):
|
|
|
- do_quadratic = True
|
|
|
- if do_mirostat is False and sampling_params.mirostat_mode == 2:
|
|
|
- do_mirostat = True
|
|
|
-
|
|
|
- if (i < sampling_metadata.num_prompts
|
|
|
- and sampling_params.prompt_logprobs is not None):
|
|
|
- # For tokens in the prompt that we only need to get their
|
|
|
- # logprobs
|
|
|
- prompt_len = sampling_metadata.prompt_lens[i]
|
|
|
- index += sampling_metadata.prompt_lens[i] - 1
|
|
|
- temperatures += [temperature] * (prompt_len - 1)
|
|
|
- top_ps += [top_p] * (prompt_len - 1)
|
|
|
- top_ks += [top_k] * (prompt_len - 1)
|
|
|
- top_as += [top_a] * (prompt_len - 1)
|
|
|
- min_ps += [min_p] * (prompt_len - 1)
|
|
|
- presence_penalties += [0] * (prompt_len - 1)
|
|
|
- frequency_penalties += [0] * (prompt_len - 1)
|
|
|
- repetition_penalties += [1] * (prompt_len - 1)
|
|
|
- tfss += [1] * (prompt_len - 1)
|
|
|
- eta_cutoffs += [0] * (prompt_len - 1)
|
|
|
- epsilon_cutoffs += [0] * (prompt_len - 1)
|
|
|
- typical_ps += [1] * (prompt_len - 1)
|
|
|
- dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
|
|
|
- dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
|
|
|
- dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
|
|
|
- smoothing_factors += [smoothing_factor] * (prompt_len - 1)
|
|
|
- smoothing_curves += [smoothing_curve] * (prompt_len - 1)
|
|
|
- prompt_tokens.extend([] for _ in range(prompt_len - 1))
|
|
|
- output_tokens.extend([] for _ in range(prompt_len - 1))
|
|
|
- for seq_id in seq_ids:
|
|
|
- seq_data = sampling_metadata.seq_data[seq_id]
|
|
|
- prompt_tokens.append(seq_data.prompt_token_ids)
|
|
|
- output_tokens.append(seq_data.output_token_ids)
|
|
|
- temperatures += [temperature] * len(seq_ids)
|
|
|
- top_ps += [top_p] * len(seq_ids)
|
|
|
- top_ks += [top_k] * len(seq_ids)
|
|
|
- top_as += [top_a] * len(seq_ids)
|
|
|
- min_ps += [min_p] * len(seq_ids)
|
|
|
- presence_penalties += [p] * len(seq_ids)
|
|
|
- frequency_penalties += [f] * len(seq_ids)
|
|
|
- repetition_penalties += [r] * len(seq_ids)
|
|
|
- tfss += [tfs] * len(seq_ids)
|
|
|
- eta_cutoffs += [eta_cutoff] * len(seq_ids)
|
|
|
- epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
|
|
|
- typical_ps += [typical_p] * len(seq_ids)
|
|
|
- dynatemp_mins += [dynatemp_min] * len(seq_ids)
|
|
|
- dynatemp_maxs += [dynatemp_max] * len(seq_ids)
|
|
|
- dynatemp_exps += [dynatemp_exp] * len(seq_ids)
|
|
|
- smoothing_factors += [smoothing_factor] * len(seq_ids)
|
|
|
- smoothing_curves += [smoothing_curve] * len(seq_ids)
|
|
|
- if sampling_params.mirostat_mode == 2:
|
|
|
- miro_indices += [(index + i) for i in range(len(seq_ids))]
|
|
|
- miro_seqids += seq_ids
|
|
|
- miro_taus += [miro_tau] * len(seq_ids)
|
|
|
- miro_etas += [miro_eta] * len(seq_ids)
|
|
|
- miro_mus += [
|
|
|
- sampling_metadata.persistent_metadata.get(sid).get(
|
|
|
- "miro_mu", sampling_params.mirostat_tau * 2)
|
|
|
- for sid in seq_ids
|
|
|
- ]
|
|
|
- index += len(seq_ids)
|
|
|
-
|
|
|
- sampling_tensors = SamplingTensors.from_lists(
|
|
|
- temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
|
|
|
- frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
|
|
|
- epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
|
|
|
- dynatemp_exps, miro_taus, miro_etas, miro_mus, miro_indices,
|
|
|
- miro_seqids, smoothing_factors, smoothing_curves, prompt_tokens,
|
|
|
- output_tokens, vocab_size, device, dtype)
|
|
|
- return (sampling_tensors, do_temperatures, do_penalties, do_topks,
|
|
|
- do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
|
|
|
- do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
|
|
|
+ do_temperatures: bool
|
|
|
+ do_dynatemps: bool
|
|
|
+ do_penalties: bool
|
|
|
+ do_top_ks: bool
|
|
|
+ do_top_ps: bool
|
|
|
+ do_top_as: bool
|
|
|
+ do_min_ps: bool
|
|
|
+ do_tfss: bool
|
|
|
+ do_eta_cutoffs: bool
|
|
|
+ do_epsilon_cutoffs: bool
|
|
|
+ do_typical_ps: bool
|
|
|
+ do_quadratic: bool
|
|
|
+ do_mirostat: bool
|
|
|
|
|
|
@classmethod
|
|
|
- def from_lists(cls, temperatures: List[float], top_ps: List[float],
|
|
|
- top_ks: List[int], top_as: List[float], min_ps: List[float],
|
|
|
- presence_penalties: List[float],
|
|
|
- frequency_penalties: List[float],
|
|
|
- repetition_penalties: List[float], tfss: List[float],
|
|
|
- eta_cutoffs: List[float], epsilon_cutoffs: List[float],
|
|
|
- typical_ps: List[float], dynatemp_mins: List[float],
|
|
|
- dynatemp_maxs: List[float], dynatemp_exps: List[float],
|
|
|
- miro_taus: List[float], miro_etas: List[float],
|
|
|
- miro_mus: List[float], miro_indices: List[int],
|
|
|
- miro_seqids: List[int], smoothing_factors: List[float],
|
|
|
- smoothing_curves: List[float],
|
|
|
- prompt_tokens: List[List[int]],
|
|
|
- output_tokens: List[List[int]], vocab_size: int,
|
|
|
- device: torch.device,
|
|
|
- dtype: torch.dtype) -> "SamplingTensors":
|
|
|
- # Note that the performance will be very bad without
|
|
|
- # pinned memory.
|
|
|
+ def from_sampling_metadata(cls, sampling_metadata: "SamplingMetadata",
|
|
|
+ vocab_size: int, tgt_device: torch.device,
|
|
|
+ float_dtype: torch.dtype) -> "SamplingTensors":
|
|
|
+ prompt_lens = sampling_metadata.prompt_lens or []
|
|
|
+ groups = sampling_metadata.seq_groups or []
|
|
|
+ seq_data = sampling_metadata.seq_data or {}
|
|
|
+ persistent = sampling_metadata.persistent_metadata
|
|
|
+
|
|
|
+ # Flattened list of (params, sid) matching the logits tensor.
|
|
|
+ # `sid < 0` implies a prompt seq.
|
|
|
+ unrolled_seqs: List[Tuple[SamplingParams, int]] = []
|
|
|
+ group_plens = prompt_lens + [0] * (len(groups) - len(prompt_lens))
|
|
|
+ for (ids, params), prompt_len in zip(groups, group_plens):
|
|
|
+ if prompt_len and params.prompt_logprobs is not None:
|
|
|
+ unrolled_seqs.extend([(params, -1)] * (prompt_len - 1))
|
|
|
+ unrolled_seqs.extend([(params, sid) for sid in ids])
|
|
|
+
|
|
|
+ T = TypeVar('T')
|
|
|
+
|
|
|
+ def _unroll(fn_val: Callable[[SamplingParams], T],
|
|
|
+ prompt: Optional[T] = None) -> List[T]:
|
|
|
+ """`fn_val` for every seq, with an override for prompt seqs."""
|
|
|
+ return [
|
|
|
+ prompt if sid < 0 and prompt is not None else fn_val(p)
|
|
|
+ for p, sid in unrolled_seqs
|
|
|
+ ]
|
|
|
+
|
|
|
+ def _index(fn_mask: Callable[[SamplingParams], bool],
|
|
|
+ prompt: Optional[bool] = None) -> List[int]:
|
|
|
+ """Index for every seq where `fn_mask` is true, with an override
|
|
|
+ for prompt seqs."""
|
|
|
+ return [
|
|
|
+ i for i, (p, sid) in enumerate(unrolled_seqs)
|
|
|
+ if (fn_mask(p) if prompt is None else (
|
|
|
+ prompt if sid < 0 else fn_mask(p)))
|
|
|
+ ]
|
|
|
+
|
|
|
+ def _filter(arr: List[T], indices: List[int]) -> List[T]:
|
|
|
+ """Return only the elements of `arr` accessed by `indices`."""
|
|
|
+ return [arr[i] for i in indices]
|
|
|
+
|
|
|
+ miro_inds = _index(lambda p: p.mirostat_mode == 2, prompt=False)
|
|
|
+ _miro_seqs = _filter(unrolled_seqs, miro_inds)
|
|
|
+
|
|
|
+ quad_inds = _index(lambda p: p.smoothing_factor != 0)
|
|
|
+ _quad_seqs = _filter(unrolled_seqs, quad_inds)
|
|
|
+
|
|
|
+ fvars = { # noqa
|
|
|
+ "temperatures": _unroll(lambda p: p.temperature),
|
|
|
+ "top_ps": _unroll(lambda p: p.top_p),
|
|
|
+ "top_as": _unroll(lambda p: p.top_a),
|
|
|
+ "min_ps": _unroll(lambda p: p.min_p),
|
|
|
+ "tfss": _unroll(lambda p: p.tfs, prompt=1),
|
|
|
+ "eta_cutoffs": _unroll(lambda p: p.eta_cutoff * 1e-4, prompt=0),
|
|
|
+ "epsilon_cutoffs": _unroll(lambda p: p.epsilon_cutoff * 1e-4, 0),
|
|
|
+ "typical_ps": _unroll(lambda p: p.typical_p, prompt=1),
|
|
|
+ "pres_penalties": _unroll(lambda p: p.presence_penalty, prompt=0),
|
|
|
+ "freq_penalties": _unroll(lambda p: p.frequency_penalty, prompt=0),
|
|
|
+ "rep_penalties": _unroll(lambda p: p.repetition_penalty, prompt=1),
|
|
|
+
|
|
|
+ "dynatemp_mins": _unroll(lambda p: p.dynatemp_min),
|
|
|
+ "dynatemp_maxs": _unroll(lambda p: p.dynatemp_max),
|
|
|
+ "dynatemp_exps": _unroll(lambda p: p.dynatemp_exponent),
|
|
|
+
|
|
|
+ "miro_taus": [p.mirostat_tau for p, _ in _miro_seqs],
|
|
|
+ "miro_etas": [p.mirostat_eta for p, _ in _miro_seqs],
|
|
|
+ "miro_mus": [persistent.get(sid, "miro_mu", p.mirostat_tau * 2)
|
|
|
+ for p, sid in _miro_seqs],
|
|
|
+
|
|
|
+ "smoothing_factors": [p.smoothing_factor for p, _ in _quad_seqs],
|
|
|
+ "smoothing_curves": [p.smoothing_curve for p, _ in _quad_seqs],
|
|
|
+ }
|
|
|
+ ivars = { # noqa
|
|
|
+ "top_ks": _unroll(lambda p: vocab_size
|
|
|
+ if p.top_k == -1 else min(p.top_k, vocab_size)),
|
|
|
+ "miro_indices": miro_inds,
|
|
|
+ "smoothing_indices": quad_inds,
|
|
|
+ }
|
|
|
+
|
|
|
+ prompt_tokens = [[] if sid < 0 else seq_data[sid].prompt_token_ids
|
|
|
+ for _, sid in unrolled_seqs]
|
|
|
+ output_tokens = [[] if sid < 0 else seq_data[sid].output_token_ids
|
|
|
+ for _, sid in unrolled_seqs]
|
|
|
+
|
|
|
+ def _unjagged(arrs: List[List[T]], padval: T) -> List[List[T]]:
|
|
|
+ max_len = max(len(arr) for arr in arrs)
|
|
|
+ return [arr + [padval] * (max_len - len(arr)) for arr in arrs]
|
|
|
+
|
|
|
+ # Note that the performance will be very bad without pinned memory.
|
|
|
+ # Pinned memory allows non-blocking transfers to device.
|
|
|
pin_memory = not in_wsl()
|
|
|
- prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
|
|
|
- prompt_padded_tokens = [
|
|
|
- tokens + [vocab_size] * (prompt_max_len - len(tokens))
|
|
|
- for tokens in prompt_tokens
|
|
|
- ]
|
|
|
- output_max_len = max(len(tokens) for tokens in output_tokens)
|
|
|
- output_padded_tokens = [
|
|
|
- tokens + [vocab_size] * (output_max_len - len(tokens))
|
|
|
- for tokens in output_tokens
|
|
|
- ]
|
|
|
-
|
|
|
- temperatures_t = torch.tensor(temperatures,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- top_ps_t = torch.tensor(top_ps,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- top_ks_t = torch.tensor(top_ks,
|
|
|
- device="cpu",
|
|
|
- dtype=torch.int,
|
|
|
- pin_memory=pin_memory)
|
|
|
- top_as_t = torch.tensor(top_as,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- min_ps_t = torch.tensor(min_ps,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- presence_penalties_t = torch.tensor(presence_penalties,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- frequency_penalties_t = torch.tensor(frequency_penalties,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- repetition_penalties_t = torch.tensor(repetition_penalties,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- tfss_t = torch.tensor(tfss,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- eta_cutoffs_t = torch.tensor(eta_cutoffs,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- typical_ps_t = torch.tensor(typical_ps,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- dynatemp_mins_t = torch.tensor(dynatemp_mins,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- dynatemp_exps_t = torch.tensor(dynatemp_exps,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- smoothing_factors_t = torch.tensor(smoothing_factors,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- smoothing_curves_t = torch.tensor(smoothing_curves,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- miro_taus_t = torch.tensor(miro_taus,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- miro_etas_t = torch.tensor(miro_etas,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- miro_mus_t = torch.tensor(miro_mus,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- miro_indices_t = torch.tensor(miro_indices,
|
|
|
- device="cpu",
|
|
|
- dtype=torch.int,
|
|
|
- pin_memory=pin_memory)
|
|
|
- prompt_tensor = torch.tensor(prompt_padded_tokens,
|
|
|
- device=device,
|
|
|
- dtype=torch.long,
|
|
|
- pin_memory=pin_memory)
|
|
|
- output_tensor = torch.tensor(output_padded_tokens,
|
|
|
- device=device,
|
|
|
- dtype=torch.long,
|
|
|
- pin_memory=pin_memory)
|
|
|
- # Because the memory is pinned, we can do non-blocking
|
|
|
- # transfer to device.
|
|
|
+
|
|
|
+ def _tensor(contents: list, dtype) -> torch.Tensor:
|
|
|
+ loc_t = torch.tensor(contents,
|
|
|
+ dtype=dtype,
|
|
|
+ device="cpu",
|
|
|
+ pin_memory=pin_memory)
|
|
|
+ return loc_t.to(device=tgt_device, non_blocking=True)
|
|
|
+
|
|
|
return cls(
|
|
|
- temperatures=temperatures_t.to(device=device, non_blocking=True),
|
|
|
- top_ps=top_ps_t.to(device=device, non_blocking=True),
|
|
|
- top_ks=top_ks_t.to(device=device, non_blocking=True),
|
|
|
- top_as=top_as_t.to(device=device, non_blocking=True),
|
|
|
- min_ps=min_ps_t.to(device=device, non_blocking=True),
|
|
|
- presence_penalties=presence_penalties_t.to(device=device,
|
|
|
- non_blocking=True),
|
|
|
- frequency_penalties=frequency_penalties_t.to(device=device,
|
|
|
- non_blocking=True),
|
|
|
- repetition_penalties=repetition_penalties_t.to(device=device,
|
|
|
- non_blocking=True),
|
|
|
- tfss=tfss_t.to(device=device, non_blocking=True),
|
|
|
- eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
|
|
|
- epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
|
|
|
- non_blocking=True),
|
|
|
- dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
|
|
|
- dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
|
|
|
- dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
|
|
|
- smoothing_factors=smoothing_factors_t.to(device=device,
|
|
|
- non_blocking=True),
|
|
|
- smoothing_curves=smoothing_curves_t.to(device=device,
|
|
|
- non_blocking=True),
|
|
|
- miro_taus=miro_taus_t.to(device=device, non_blocking=True),
|
|
|
- miro_etas=miro_etas_t.to(device=device, non_blocking=True),
|
|
|
- miro_mus=miro_mus_t.to(device=device, non_blocking=True),
|
|
|
- miro_indices=miro_indices_t.to(device=device, non_blocking=True),
|
|
|
- miro_seqids=miro_seqids,
|
|
|
- typical_ps=typical_ps_t.to(device=device, non_blocking=True),
|
|
|
- prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
|
|
|
- output_tokens=output_tensor.to(device=device, non_blocking=True),
|
|
|
+ # Flags and non-tensor fields
|
|
|
+ do_temperatures=any(x != 1 for x in fvars["temperatures"]),
|
|
|
+ do_dynatemps=(any(fvars["dynatemp_mins"])
|
|
|
+ or any(fvars["dynatemp_maxs"])),
|
|
|
+ do_top_ks=any(x != vocab_size for x in ivars["top_ks"]),
|
|
|
+ do_top_ps=any(x != 1 for x in fvars["top_ps"]),
|
|
|
+ do_top_as=any(fvars["top_as"]),
|
|
|
+ do_min_ps=any(fvars["min_ps"]),
|
|
|
+ do_tfss=any(x != 1 for x in fvars["tfss"]),
|
|
|
+ do_eta_cutoffs=any(fvars["eta_cutoffs"]),
|
|
|
+ do_epsilon_cutoffs=any(fvars["epsilon_cutoffs"]),
|
|
|
+ do_typical_ps=any(x != 1 for x in fvars["typical_ps"]),
|
|
|
+ do_penalties=(any(fvars["pres_penalties"])
|
|
|
+ or any(fvars["freq_penalties"])
|
|
|
+ or any(x != 1 for x in fvars["rep_penalties"])),
|
|
|
+ do_quadratic=len(quad_inds) > 0,
|
|
|
+ do_mirostat=len(miro_inds) > 0,
|
|
|
+ miro_seqids=_filter([s for _, s in unrolled_seqs], miro_inds),
|
|
|
+ # Float tensors
|
|
|
+ **{n: _tensor(vals, float_dtype)
|
|
|
+ for n, vals in fvars.items()},
|
|
|
+ # Integer tensors
|
|
|
+ **{n: _tensor(vals, torch.int)
|
|
|
+ for n, vals in ivars.items()},
|
|
|
+ # Token ID tensors
|
|
|
+ prompt_tokens=_tensor(_unjagged(prompt_tokens, vocab_size),
|
|
|
+ torch.long),
|
|
|
+ output_tokens=_tensor(_unjagged(output_tokens, vocab_size),
|
|
|
+ torch.long),
|
|
|
)
|