@@ -379,6 +379,7 @@ class SamplingTensors:
     presence_penalties: torch.Tensor
     frequency_penalties: torch.Tensor
     repetition_penalties: torch.Tensor
+    no_repeat_ngram_sizes: torch.Tensor
     tfss: torch.Tensor
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
@@ -409,7 +410,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool, bool, bool, bool, bool]:
+               bool, bool, bool, bool, bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -429,6 +430,7 @@ class SamplingTensors:
         presence_penalties: List[float] = []
         frequency_penalties: List[float] = []
         repetition_penalties: List[float] = []
+        no_repeat_ngram_sizes: List[int] = []
         tfss: List[float] = []
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
@@ -446,6 +448,7 @@ class SamplingTensors:
         dry_sequence_breaker_ids: List[List[int]] = []
 
         do_penalties = False
+        do_no_repeat_ngrams = False
         do_temperatures = False
         do_top_p_top_k = False
         do_top_as = False
@@ -493,6 +496,7 @@ class SamplingTensors:
             do_penalties |= (abs(params.presence_penalty) >= _SAMPLING_EPS or
                              abs(params.frequency_penalty) >= _SAMPLING_EPS or
                              params.repetition_penalty > 1.0)
+            do_no_repeat_ngrams |= params.no_repeat_ngram_size > 0
             do_tfss |= params.tfs < 1.0 - _SAMPLING_EPS
             do_eta_cutoffs |= params.eta_cutoff > _SAMPLING_EPS
             do_epsilon_cutoffs |= params.epsilon_cutoff > _SAMPLING_EPS
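For context, `no_repeat_ngram_size > 0` enables n-gram blocking in the spirit of the Hugging Face `NoRepeatNGramLogitsProcessor`: before sampling, any token that would complete an n-gram already present in the sequence has its logit forced to negative infinity. This patch only threads the per-sequence sizes through `SamplingTensors`; the masking itself lives elsewhere in the sampler. A minimal single-sequence sketch of the idea (`ban_repeated_ngrams` is a hypothetical helper, not code from this patch):

    from typing import List

    import torch

    def ban_repeated_ngrams(logits: torch.Tensor, token_ids: List[int],
                            ngram_size: int) -> torch.Tensor:
        # Ban every token that would complete an n-gram already present
        # in token_ids; logits is the 1-D vocab-sized row of one sequence.
        if ngram_size <= 0 or len(token_ids) < ngram_size:
            return logits
        # The (n-1)-token suffix that the next token would extend.
        prefix = tuple(token_ids[-(ngram_size - 1):]) if ngram_size > 1 else ()
        for i in range(len(token_ids) - ngram_size + 1):
            ngram = tuple(token_ids[i:i + ngram_size])
            if ngram[:-1] == prefix:               # prefix occurred before...
                logits[ngram[-1]] = float("-inf")  # ...ban its continuation
        return logits

With `ngram_size=1` the prefix is empty, so every previously generated token is banned; with `ngram_size=0` the helper is a no-op, which is why the flag above only trips for sizes greater than zero.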
@@ -529,6 +533,7 @@ class SamplingTensors:
             presence_penalties += [params.presence_penalty] * n_seqs
             frequency_penalties += [params.frequency_penalty] * n_seqs
             repetition_penalties += [params.repetition_penalty] * n_seqs
+            no_repeat_ngram_sizes += [params.no_repeat_ngram_size] * n_seqs
             tfss += [params.tfs] * n_seqs
             eta_cutoffs += [params.eta_cutoff] * n_seqs
             epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
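Note the `* n_seqs` replication: a sequence group that fans out into several sequences contributes one entry per sequence, so index i of every flat list stays aligned with row i of the batched logits. A toy illustration with hypothetical values:

    from typing import List

    no_repeat_ngram_sizes: List[int] = []
    # Two requests: one with n_seqs=3 and ngram size 2, one with
    # n_seqs=1 and the feature disabled (size 0).
    for n_seqs, size in [(3, 2), (1, 0)]:
        no_repeat_ngram_sizes += [size] * n_seqs
    assert no_repeat_ngram_sizes == [2, 2, 2, 0]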
@@ -566,7 +571,7 @@ class SamplingTensors:
                 sampling_seeds.append(seq_seeds)
             sample_indices.extend(seq_group.sample_indices)
 
-        if do_penalties or do_dry:
+        if do_penalties or do_dry or do_no_repeat_ngrams:
             for seq_group in sampling_metadata.seq_groups:
                 seq_ids = seq_group.seq_ids
                 if (seq_group.is_prompt
@@ -588,16 +593,16 @@ class SamplingTensors:
             temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
             temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
-            tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, xtc_thresholds, xtc_probabilities, nsigmas,
-            dry_multipliers, dry_bases, dry_allowed_lengths,
-            dry_sequence_breaker_ids, sampling_seeds, sample_indices,
-            prompt_tokens, output_tokens, vocab_size, extra_seeds_to_generate,
-            device, dtype)
-        return (sampling_tensors, do_penalties, do_temperatures,
-                do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
-                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
-                do_nsigmas, do_dry, do_temp_last)
+            no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
+            typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
+            xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
+            dry_allowed_lengths, dry_sequence_breaker_ids, sampling_seeds,
+            sample_indices, prompt_tokens, output_tokens, vocab_size,
+            extra_seeds_to_generate, device, dtype)
+        return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
+                do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
+                do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
+                do_quadratic, do_xtc, do_nsigmas, do_dry, do_temp_last)
 
     @classmethod
     def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
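Because callers unpack this tuple positionally, adding `do_no_repeat_ngrams` right after `do_penalties` means every call site must grow by one name in exactly that position. A sketch of what an updated call site would look like (the surrounding variables such as `sampling_metadata`, `vocab_size`, `device`, and `dtype` are assumed to be in scope in the engine; this is an illustration, not a hunk from the patch):

    (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
     do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
     do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
     do_dry, do_temp_last) = SamplingTensors.from_sampling_metadata(
         sampling_metadata, vocab_size, device, dtype)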
@@ -606,7 +611,8 @@ class SamplingTensors:
                    top_ks: List[int], top_as: List[float],
                    min_ps: List[float], presence_penalties: List[float],
                    frequency_penalties: List[float],
-                   repetition_penalties: List[float], tfss: List[float],
+                   repetition_penalties: List[float],
+                   no_repeat_ngram_sizes: List[int], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], xtc_thresholds: List[float],
@@ -708,6 +714,12 @@ class SamplingTensors:
             dtype=dtype,
             pin_memory=pin_memory,
         )
+        no_repeat_ngram_sizes_t = torch.tensor(
+            no_repeat_ngram_sizes,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
         top_ks_t = torch.tensor(
             top_ks,
             device="cpu",
@@ -819,6 +831,8 @@ class SamplingTensors:
                                                          non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
+            no_repeat_ngram_sizes=no_repeat_ngram_sizes_t.to(device=device,
+                                                             non_blocking=True),
             tfss=tfss_t.to(device=device, non_blocking=True),
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
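Like the other fields, the new tensor is built on the CPU (in pinned memory when `pin_memory` is set) and then copied to the accelerator with `non_blocking=True`, so the host-to-device transfer can overlap with other work instead of stalling the Python thread. A minimal standalone sketch of the same pattern, with hypothetical values:

    import torch

    # One no_repeat_ngram_size per sequence in the batch.
    sizes = [2, 2, 2, 0]
    use_cuda = torch.cuda.is_available()
    # Pinned (page-locked) host memory is what makes the copy below truly
    # asynchronous on CUDA; without it, non_blocking=True degrades to a
    # synchronous copy.
    cpu_t = torch.tensor(sizes, device="cpu", dtype=torch.int,
                         pin_memory=use_cuda)
    device_t = cpu_t.to(device="cuda" if use_cuda else "cpu",
                        non_blocking=True)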