|
@@ -386,6 +386,10 @@ class SamplingTensors:
|
|
|
smoothing_curves: torch.Tensor
|
|
|
xtc_thresholds: torch.Tensor
|
|
|
xtc_probabilities: torch.Tensor
|
|
|
+ kl_thresholds: torch.Tensor
|
|
|
+ jsd_thresholds: torch.Tensor
|
|
|
+ min_typical_ps: torch.Tensor
|
|
|
+ max_typical_ps: torch.Tensor
|
|
|
sampling_seeds: torch.Tensor
|
|
|
sample_indices: torch.Tensor
|
|
|
extra_seeds: Optional[torch.Tensor]
|
|
@@ -403,7 +407,7 @@ class SamplingTensors:
|
|
|
extra_seeds_to_generate: int = 0,
|
|
|
extra_entropy: Optional[Tuple[int, ...]] = None
|
|
|
) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
|
|
|
- bool, bool, bool, bool, bool]:
|
|
|
+ bool, bool, bool, bool, bool, bool, bool, bool]:
|
|
|
"""
|
|
|
extra_seeds_to_generate: extra seeds to generate using the
|
|
|
user-defined seed for each sequence.
|
|
@@ -431,6 +435,10 @@ class SamplingTensors:
|
|
|
smoothing_curves: List[float] = []
|
|
|
xtc_thresholds: List[float] = []
|
|
|
xtc_probabilities: List[float] = []
|
|
|
+ kl_thresholds: List[float] = []
|
|
|
+ jsd_thresholds: List[float] = []
|
|
|
+ min_typical_ps: List[float] = []
|
|
|
+ max_typical_ps: List[float] = []
|
|
|
sampling_seeds: List[int] = []
|
|
|
sample_indices: List[int] = []
|
|
|
do_penalties = False
|
|
@@ -444,6 +452,9 @@ class SamplingTensors:
|
|
|
do_typical_ps = False
|
|
|
do_quadratic = False
|
|
|
do_xtc = False
|
|
|
+ do_kl_threshold = False
|
|
|
+ do_jsd_threshold = False
|
|
|
+ do_dynatypical_p = False
|
|
|
do_temp_last = False
|
|
|
|
|
|
if _USE_TRITON_SAMPLER:
|
|
@@ -476,6 +487,10 @@ class SamplingTensors:
|
|
|
smoothing_curve = sampling_params.smoothing_curve
|
|
|
xtc_threshold = sampling_params.xtc_threshold
|
|
|
xtc_probability = sampling_params.xtc_probability
|
|
|
+ kl_threshold = sampling_params.kl_threshold
|
|
|
+ jsd_threshold = sampling_params.jsd_threshold
|
|
|
+ min_typical_p = sampling_params.min_typical_p
|
|
|
+ max_typical_p = sampling_params.max_typical_p
|
|
|
|
|
|
# k should not be greater than the vocab size.
|
|
|
top_k = min(sampling_params.top_k, vocab_size)
|
|
@@ -511,6 +526,13 @@ class SamplingTensors:
|
|
|
do_quadratic = True
|
|
|
if do_xtc is False and xtc_probability > _SAMPLING_EPS:
|
|
|
do_xtc = True
|
|
|
+ if do_kl_threshold is False and kl_threshold > _SAMPLING_EPS:
|
|
|
+ do_kl_threshold = True
|
|
|
+ if do_jsd_threshold is False and jsd_threshold > _SAMPLING_EPS:
|
|
|
+ do_jsd_threshold = True
|
|
|
+ if do_dynatypical_p is False and (min_typical_p < 1.0 - _SAMPLING_EPS
|
|
|
+ or max_typical_p < 1.0 - _SAMPLING_EPS):
|
|
|
+ do_dynatypical_p = True
|
|
|
if do_temp_last is False and temperature_last:
|
|
|
do_temp_last = True
|
|
|
|
|
@@ -541,6 +563,10 @@ class SamplingTensors:
|
|
|
smoothing_curves += [smoothing_curve] * prefill_len
|
|
|
xtc_thresholds += [xtc_threshold] * prefill_len
|
|
|
xtc_probabilities += [xtc_probability] * prefill_len
|
|
|
+ kl_thresholds += [kl_threshold] * prefill_len
|
|
|
+ jsd_thresholds += [jsd_threshold] * prefill_len
|
|
|
+ min_typical_ps += [min_typical_p] * prefill_len
|
|
|
+ max_typical_ps += [max_typical_p] * prefill_len
|
|
|
|
|
|
if seq_group.do_sample:
|
|
|
sample_lens = len(seq_group.sample_indices)
|
|
@@ -565,6 +591,10 @@ class SamplingTensors:
|
|
|
smoothing_curves += [smoothing_curve] * len(seq_ids)
|
|
|
xtc_thresholds += [xtc_threshold] * len(seq_ids)
|
|
|
xtc_probabilities += [xtc_probability] * len(seq_ids)
|
|
|
+ kl_thresholds += [kl_threshold] * len(seq_ids)
|
|
|
+ jsd_thresholds += [jsd_threshold] * len(seq_ids)
|
|
|
+ min_typical_ps += [min_typical_p] * len(seq_ids)
|
|
|
+ max_typical_ps += [max_typical_p] * len(seq_ids)
|
|
|
|
|
|
if _USE_TRITON_SAMPLER:
|
|
|
if is_prompt:
|
|
@@ -609,12 +639,14 @@ class SamplingTensors:
|
|
|
temperature_lasts, top_ps, top_ks, top_as, min_ps,
|
|
|
presence_penalties, frequency_penalties, repetition_penalties,
|
|
|
tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
|
|
|
- smoothing_curves, xtc_thresholds, xtc_probabilities,sampling_seeds,
|
|
|
- sample_indices, prompt_tokens, output_tokens, vocab_size,
|
|
|
- extra_seeds_to_generate, device, dtype)
|
|
|
+ smoothing_curves, xtc_thresholds, xtc_probabilities, kl_thresholds,
|
|
|
+ jsd_thresholds, min_typical_ps, max_typical_ps,
|
|
|
+ sampling_seeds, sample_indices, prompt_tokens, output_tokens,
|
|
|
+ vocab_size, extra_seeds_to_generate, device, dtype)
|
|
|
return (sampling_tensors, do_penalties, do_temperatures,
|
|
|
do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
|
|
|
do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
|
|
|
+ do_kl_threshold, do_jsd_threshold, do_dynatypical_p,
|
|
|
do_temp_last)
|
|
|
|
|
|
@classmethod
|
|
@@ -628,7 +660,9 @@ class SamplingTensors:
|
|
|
eta_cutoffs: List[float], epsilon_cutoffs: List[float],
|
|
|
typical_ps: List[float], smoothing_factors: List[float],
|
|
|
smoothing_curves: List[float], xtc_thresholds: List[float],
|
|
|
- xtc_probabilities: List[float], sampling_seeds: List[int],
|
|
|
+ xtc_probabilities: List[float], kl_thresholds: List[float],
|
|
|
+ jsd_thresholds: List[float], min_typical_ps: List[float],
|
|
|
+ max_typical_ps: List[float], sampling_seeds: List[int],
|
|
|
sample_indices: List[int], prompt_tokens: List[array],
|
|
|
output_tokens: List[array], vocab_size: int,
|
|
|
extra_seeds_to_generate: int, device: torch.device,
|
|
@@ -760,6 +794,22 @@ class SamplingTensors:
|
|
|
device="cpu",
|
|
|
dtype=dtype,
|
|
|
pin_memory=pin_memory)
|
|
|
+ kl_thresholds_t = torch.tensor(kl_thresholds,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory)
|
|
|
+ jsd_thresholds_t = torch.tensor(jsd_thresholds,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory)
|
|
|
+ min_typical_ps_t = torch.tensor(min_typical_ps,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory)
|
|
|
+ max_typical_ps_t = torch.tensor(max_typical_ps,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory)
|
|
|
sample_indices_t = torch.tensor(
|
|
|
sample_indices,
|
|
|
device="cpu",
|
|
@@ -816,6 +866,10 @@ class SamplingTensors:
|
|
|
non_blocking=True),
|
|
|
xtc_probabilities=xtc_probabilities_t.to(device=device,
|
|
|
non_blocking=True),
|
|
|
+ kl_thresholds=kl_thresholds_t.to(device=device, non_blocking=True),
|
|
|
+ jsd_thresholds=jsd_thresholds_t.to(device=device, non_blocking=True),
|
|
|
+ min_typical_ps=min_typical_ps_t.to(device=device, non_blocking=True),
|
|
|
+ max_typical_ps=max_typical_ps_t.to(device=device, non_blocking=True),
|
|
|
typical_ps=typical_ps_t.to(device=device, non_blocking=True),
|
|
|
prompt_tokens=prompt_t.to(device=device, non_blocking=True),
|
|
|
output_tokens=output_t.to(device=device, non_blocking=True),
|