|
@@ -367,6 +367,9 @@ class SamplingTensors:
|
|
"""Tensors for sampling."""
|
|
"""Tensors for sampling."""
|
|
|
|
|
|
temperatures: torch.Tensor
|
|
temperatures: torch.Tensor
|
|
|
|
+ dynatemp_mins: torch.Tensor
|
|
|
|
+ dynatemp_maxs: torch.Tensor
|
|
|
|
+ dynatemp_exps: torch.Tensor
|
|
temperature_lasts: torch.Tensor
|
|
temperature_lasts: torch.Tensor
|
|
top_ps: torch.Tensor
|
|
top_ps: torch.Tensor
|
|
top_ks: torch.Tensor
|
|
top_ks: torch.Tensor
|
|
@@ -400,7 +403,7 @@ class SamplingTensors:
|
|
extra_seeds_to_generate: int = 0,
|
|
extra_seeds_to_generate: int = 0,
|
|
extra_entropy: Optional[Tuple[int, ...]] = None
|
|
extra_entropy: Optional[Tuple[int, ...]] = None
|
|
) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
|
|
) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
|
|
- bool, bool, bool, bool]:
|
|
|
|
|
|
+ bool, bool, bool, bool, bool]:
|
|
"""
|
|
"""
|
|
extra_seeds_to_generate: extra seeds to generate using the
|
|
extra_seeds_to_generate: extra seeds to generate using the
|
|
user-defined seed for each sequence.
|
|
user-defined seed for each sequence.
|
|
@@ -410,6 +413,9 @@ class SamplingTensors:
|
|
output_tokens: List[array] = []
|
|
output_tokens: List[array] = []
|
|
top_ks: List[int] = []
|
|
top_ks: List[int] = []
|
|
temperatures: List[float] = []
|
|
temperatures: List[float] = []
|
|
|
|
+ dynatemp_mins: List[float] = []
|
|
|
|
+ dynatemp_maxs: List[float] = []
|
|
|
|
+ dynatemp_exps: List[float] = []
|
|
temperature_lasts: List[bool] = []
|
|
temperature_lasts: List[bool] = []
|
|
top_ps: List[float] = []
|
|
top_ps: List[float] = []
|
|
top_as: List[float] = []
|
|
top_as: List[float] = []
|
|
@@ -428,6 +434,7 @@ class SamplingTensors:
|
|
sampling_seeds: List[int] = []
|
|
sampling_seeds: List[int] = []
|
|
sample_indices: List[int] = []
|
|
sample_indices: List[int] = []
|
|
do_penalties = False
|
|
do_penalties = False
|
|
|
|
+ do_temperatures = False
|
|
do_top_p_top_k = False
|
|
do_top_p_top_k = False
|
|
do_top_as = False
|
|
do_top_as = False
|
|
do_min_p = False
|
|
do_min_p = False
|
|
@@ -451,6 +458,9 @@ class SamplingTensors:
|
|
seq_ids = seq_group.seq_ids
|
|
seq_ids = seq_group.seq_ids
|
|
sampling_params = seq_group.sampling_params
|
|
sampling_params = seq_group.sampling_params
|
|
temperature = sampling_params.temperature
|
|
temperature = sampling_params.temperature
|
|
|
|
+ dynatemp_min = sampling_params.dynatemp_min
|
|
|
|
+ dynatemp_max = sampling_params.dynatemp_max
|
|
|
|
+ dynatemp_exp = sampling_params.dynatemp_exponent
|
|
temperature_last = sampling_params.temperature_last
|
|
temperature_last = sampling_params.temperature_last
|
|
p = sampling_params.presence_penalty
|
|
p = sampling_params.presence_penalty
|
|
f = sampling_params.frequency_penalty
|
|
f = sampling_params.frequency_penalty
|
|
@@ -475,6 +485,8 @@ class SamplingTensors:
|
|
# (i.e., greedy sampling or beam search).
|
|
# (i.e., greedy sampling or beam search).
|
|
# Set the temperature to 1 to avoid division by zero.
|
|
# Set the temperature to 1 to avoid division by zero.
|
|
temperature = 1.0
|
|
temperature = 1.0
|
|
|
|
+ if not do_temperatures and temperature != 1.0:
|
|
|
|
+ do_temperatures = True
|
|
if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
|
|
if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
|
|
or top_k != vocab_size):
|
|
or top_k != vocab_size):
|
|
do_top_p_top_k = True
|
|
do_top_p_top_k = True
|
|
@@ -510,6 +522,9 @@ class SamplingTensors:
|
|
assert query_len is not None
|
|
assert query_len is not None
|
|
prefill_len = len(seq_group.prompt_logprob_indices)
|
|
prefill_len = len(seq_group.prompt_logprob_indices)
|
|
temperatures += [temperature] * prefill_len
|
|
temperatures += [temperature] * prefill_len
|
|
|
|
+ dynatemp_mins += [dynatemp_min] * prefill_len
|
|
|
|
+ dynatemp_maxs += [dynatemp_max] * prefill_len
|
|
|
|
+ dynatemp_exps += [dynatemp_exp] * prefill_len
|
|
temperature_lasts += [temperature_last] * prefill_len
|
|
temperature_lasts += [temperature_last] * prefill_len
|
|
top_ps += [top_p] * prefill_len
|
|
top_ps += [top_p] * prefill_len
|
|
top_ks += [top_k] * prefill_len
|
|
top_ks += [top_k] * prefill_len
|
|
@@ -531,6 +546,9 @@ class SamplingTensors:
|
|
sample_lens = len(seq_group.sample_indices)
|
|
sample_lens = len(seq_group.sample_indices)
|
|
assert sample_lens == len(seq_ids)
|
|
assert sample_lens == len(seq_ids)
|
|
temperatures += [temperature] * len(seq_ids)
|
|
temperatures += [temperature] * len(seq_ids)
|
|
|
|
+ dynatemp_mins += [dynatemp_min] * len(seq_ids)
|
|
|
|
+ dynatemp_maxs += [dynatemp_max] * len(seq_ids)
|
|
|
|
+ dynatemp_exps += [dynatemp_exp] * len(seq_ids)
|
|
temperature_lasts += [temperature_last] * len(seq_ids)
|
|
temperature_lasts += [temperature_last] * len(seq_ids)
|
|
top_ps += [top_p] * len(seq_ids)
|
|
top_ps += [top_p] * len(seq_ids)
|
|
top_ks += [top_k] * len(seq_ids)
|
|
top_ks += [top_k] * len(seq_ids)
|
|
@@ -587,18 +605,21 @@ class SamplingTensors:
|
|
output_tokens.append(seq_data.output_token_ids_array)
|
|
output_tokens.append(seq_data.output_token_ids_array)
|
|
|
|
|
|
sampling_tensors = SamplingTensors.from_lists(
|
|
sampling_tensors = SamplingTensors.from_lists(
|
|
- temperatures, temperature_lasts, top_ps, top_ks, top_as, min_ps,
|
|
|
|
|
|
+ temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
|
|
|
|
+ temperature_lasts, top_ps, top_ks, top_as, min_ps,
|
|
presence_penalties, frequency_penalties, repetition_penalties,
|
|
presence_penalties, frequency_penalties, repetition_penalties,
|
|
tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
|
|
tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
|
|
smoothing_curves, xtc_thresholds, xtc_probabilities,sampling_seeds,
|
|
smoothing_curves, xtc_thresholds, xtc_probabilities,sampling_seeds,
|
|
sample_indices, prompt_tokens, output_tokens, vocab_size,
|
|
sample_indices, prompt_tokens, output_tokens, vocab_size,
|
|
extra_seeds_to_generate, device, dtype)
|
|
extra_seeds_to_generate, device, dtype)
|
|
- return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
|
|
|
|
- do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
|
|
|
|
- do_typical_ps, do_quadratic, do_xtc, do_temp_last)
|
|
|
|
|
|
+ return (sampling_tensors, do_penalties, do_temperatures,
|
|
|
|
+ do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
|
|
|
|
+ do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
|
|
|
|
+ do_temp_last)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
- def from_lists(cls, temperatures: List[float],
|
|
|
|
|
|
+ def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
|
|
|
|
+ dynatemp_maxs: List[float], dynatemp_exps: List[float],
|
|
temperature_lasts: List[bool], top_ps: List[float],
|
|
temperature_lasts: List[bool], top_ps: List[float],
|
|
top_ks: List[int], top_as: List[float],
|
|
top_ks: List[int], top_as: List[float],
|
|
min_ps: List[float], presence_penalties: List[float],
|
|
min_ps: List[float], presence_penalties: List[float],
|
|
@@ -643,6 +664,24 @@ class SamplingTensors:
|
|
dtype=dtype,
|
|
dtype=dtype,
|
|
pin_memory=pin_memory,
|
|
pin_memory=pin_memory,
|
|
)
|
|
)
|
|
|
|
+ dynatemp_mins_t = torch.tensor(
|
|
|
|
+ dynatemp_mins,
|
|
|
|
+ device="cpu",
|
|
|
|
+ dtype=dtype,
|
|
|
|
+ pin_memory=pin_memory,
|
|
|
|
+ )
|
|
|
|
+ dynatemp_maxs_t = torch.tensor(
|
|
|
|
+ dynatemp_maxs,
|
|
|
|
+ device="cpu",
|
|
|
|
+ dtype=dtype,
|
|
|
|
+ pin_memory=pin_memory,
|
|
|
|
+ )
|
|
|
|
+ dynatemp_exps_t = torch.tensor(
|
|
|
|
+ dynatemp_exps,
|
|
|
|
+ device="cpu",
|
|
|
|
+ dtype=dtype,
|
|
|
|
+ pin_memory=pin_memory,
|
|
|
|
+ )
|
|
temp_lasts_t = torch.tensor(
|
|
temp_lasts_t = torch.tensor(
|
|
temperature_lasts,
|
|
temperature_lasts,
|
|
device="cpu",
|
|
device="cpu",
|
|
@@ -751,6 +790,9 @@ class SamplingTensors:
|
|
|
|
|
|
return cls(
|
|
return cls(
|
|
temperatures=temperatures_t.to(device=device, non_blocking=True),
|
|
temperatures=temperatures_t.to(device=device, non_blocking=True),
|
|
|
|
+ dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
|
|
|
|
+ dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
|
|
|
|
+ dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
|
|
temperature_lasts=temp_lasts_t.to(device=device, non_blocking=True),
|
|
temperature_lasts=temp_lasts_t.to(device=device, non_blocking=True),
|
|
top_ps=top_ps_t.to(device=device, non_blocking=True),
|
|
top_ps=top_ps_t.to(device=device, non_blocking=True),
|
|
top_ks=top_ks_t.to(device=device, non_blocking=True),
|
|
top_ks=top_ks_t.to(device=device, non_blocking=True),
|