@@ -393,6 +393,7 @@ class SamplingTensors:
     dry_bases: torch.Tensor
     dry_allowed_lengths: torch.Tensor
     dry_sequence_breaker_ids: torch.Tensor
+    skews: torch.Tensor
     sampling_seeds: torch.Tensor
     sample_indices: torch.Tensor
     extra_seeds: Optional[torch.Tensor]
@@ -410,7 +411,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -446,6 +447,7 @@ class SamplingTensors:
         dry_bases: List[float] = []
         dry_allowed_lengths: List[int] = []
         dry_sequence_breaker_ids: List[List[int]] = []
+        skews: List[float] = []

         do_penalties = False
         do_no_repeat_ngrams = False
@@ -461,6 +463,7 @@ class SamplingTensors:
         do_xtc = False
         do_nsigmas = False
         do_dry = False
+        do_skews = False
         do_temp_last = False

         if _USE_TRITON_SAMPLER:
@@ -506,6 +509,7 @@ class SamplingTensors:
             do_xtc |= params.xtc_probability > _SAMPLING_EPS
             do_nsigmas |= params.nsigma > _SAMPLING_EPS
             do_dry |= params.dry_multiplier > _SAMPLING_EPS
+            do_skews |= abs(params.skew) > _SAMPLING_EPS

             do_temp_last |= params.temperature_last

@@ -548,6 +552,7 @@ class SamplingTensors:
             dry_allowed_lengths += [params.dry_allowed_length] * n_seqs
             dry_sequence_breaker_ids += (
                 [params.dry_sequence_breaker_ids] * n_seqs)
+            skews += [params.skew] * n_seqs

             if _USE_TRITON_SAMPLER:
                 if is_prompt:
@@ -596,13 +601,14 @@ class SamplingTensors:
             no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
             typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
             xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
-            dry_allowed_lengths, dry_sequence_breaker_ids, sampling_seeds,
-            sample_indices, prompt_tokens, output_tokens, vocab_size,
-            extra_seeds_to_generate, device, dtype)
+            dry_allowed_lengths, dry_sequence_breaker_ids, skews,
+            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
+            vocab_size, extra_seeds_to_generate, device, dtype)
         return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
                 do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
                 do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
-                do_quadratic, do_xtc, do_nsigmas, do_dry, do_temp_last)
+                do_quadratic, do_xtc, do_nsigmas, do_dry, do_skews,
+                do_temp_last)

     @classmethod
     def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
@@ -620,7 +626,7 @@ class SamplingTensors:
                    dry_multipliers: List[float], dry_bases: List[float],
                    dry_allowed_lengths: List[int],
                    dry_sequence_breaker_ids: List[List[int]],
-                   sampling_seeds: List[List[int]],
+                   skews: List[float], sampling_seeds: List[List[int]],
                    sample_indices: List[int], prompt_tokens: List[array],
                    output_tokens: List[array], vocab_size: int,
                    extra_seeds_to_generate: int, device: torch.device,
@@ -786,6 +792,12 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         )
+        skews_t = torch.tensor(
+            skews,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )

         sample_indices_t = torch.tensor(
             sample_indices,
@@ -853,6 +865,7 @@ class SamplingTensors:
                                                          non_blocking=True),
             dry_sequence_breaker_ids=dry_sequence_breakers_t.to(device=device,
                                                                 non_blocking=True),
+            skews=skews_t.to(device=device, non_blocking=True),
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True),
             output_tokens=output_t.to(device=device, non_blocking=True),
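
Reviewer note: the hunks above only plumb the per-sequence `skews` values and the `do_skews` flag through `SamplingTensors` and `from_lists`; the point where the skew actually modifies the logits lies outside this section. A minimal sketch of how a sampler could consume the new tensor, gated on `do_skews` like the other `do_*` flags, is shown below. The `_apply_skew` helper and its exponent-based formulation are illustrative assumptions, not code added by this patch.

    # Hypothetical sketch -- not part of this diff.
    import torch

    def _apply_skew(logits: torch.Tensor, skews: torch.Tensor) -> torch.Tensor:
        """Bias each sequence's token distribution by a skew exponent.

        logits: [num_seqs, vocab_size]; skews: [num_seqs].
        A skew of 0 leaves the distribution unchanged; positive values
        sharpen it, values in (-1, 0) flatten it.
        """
        probs = torch.softmax(logits, dim=-1)
        skewed = probs.pow(1.0 + skews.unsqueeze(1))
        # Log-probabilities are valid logits; masked tokens (prob 0) map to -inf.
        return torch.log(skewed)

    # Expected call site in the sampler, mirroring the other flags:
    #   if do_skews:
    #       logits = _apply_skew(logits, sampling_tensors.skews)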