@@ -431,7 +431,7 @@ class SamplingTensors:
         smoothing_curves: List[float] = []
         xtc_thresholds: List[float] = []
         xtc_probabilities: List[float] = []
-        sampling_seeds: List[int] = []
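+        # One inner seed list per sequence (built via _get_sequence_seeds).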
+        sampling_seeds: List[List[int]] = []
         sample_indices: List[int] = []
         do_penalties = False
         do_temperatures = False
@@ -456,124 +456,93 @@ class SamplingTensors:
         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
             seq_ids = seq_group.seq_ids
-            sampling_params = seq_group.sampling_params
-            temperature = sampling_params.temperature
-            dynatemp_min = sampling_params.dynatemp_min
-            dynatemp_max = sampling_params.dynatemp_max
-            dynatemp_exp = sampling_params.dynatemp_exponent
-            temperature_last = sampling_params.temperature_last
-            p = sampling_params.presence_penalty
-            f = sampling_params.frequency_penalty
-            r = sampling_params.repetition_penalty
-            top_p = sampling_params.top_p
-            top_a = sampling_params.top_a
-            min_p = sampling_params.min_p
-            tfs = sampling_params.tfs
-            eta_cutoff = sampling_params.eta_cutoff
-            epsilon_cutoff = sampling_params.epsilon_cutoff
-            typical_p = sampling_params.typical_p
-            smoothing_factor = sampling_params.smoothing_factor
-            smoothing_curve = sampling_params.smoothing_curve
-            xtc_threshold = sampling_params.xtc_threshold
-            xtc_probability = sampling_params.xtc_probability
+            params = seq_group.sampling_params

             # k should not be greater than the vocab size.
-            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = min(params.top_k, vocab_size)
             top_k = vocab_size if top_k == -1 else top_k
+
+            temperature = params.temperature
             if temperature < _SAMPLING_EPS:
                 # NOTE: Zero temperature means deterministic sampling
                 # (i.e., greedy sampling or beam search).
                 # Set the temperature to 1 to avoid division by zero.
                 temperature = 1.0
-            if not do_temperatures and temperature != 1.0:
-                do_temperatures = True
-            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
-                                       or top_k != vocab_size):
-                do_top_p_top_k = True
-            if do_top_as is False and top_a > 0.0:
-                do_top_as = True
-            if not do_min_p and min_p > _SAMPLING_EPS:
-                do_min_p = True
-            if not do_penalties and (abs(p) >= _SAMPLING_EPS
-                                     or abs(f) >= _SAMPLING_EPS
-                                     or abs(r - 1.0) >= _SAMPLING_EPS):
-                do_penalties = True
-            if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
-                do_tfss = True
-            if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
-                do_eta_cutoffs = True
-            if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
-                do_epsilon_cutoffs = True
-            if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
-                do_typical_ps = True
-            if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
-                                          or smoothing_curve > 1.0):
-                do_quadratic = True
-            if do_xtc is False and xtc_probability > _SAMPLING_EPS:
-                do_xtc = True
-            if do_temp_last is False and temperature_last:
-                do_temp_last = True
+
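+            # OR each group's settings into the batch-wide flags so that a
+            # sampling transform runs only when some sequence requests it.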
+            do_temperatures |= (temperature != 1.0 or
+                                params.dynatemp_min > _SAMPLING_EPS or
+                                params.dynatemp_max > _SAMPLING_EPS)
+            do_top_p_top_k |= (params.top_p < 1.0 - _SAMPLING_EPS or
+                               top_k != vocab_size)
+            do_top_as |= params.top_a > 0.0
+            do_min_p |= params.min_p > _SAMPLING_EPS
+            do_penalties |= (abs(params.presence_penalty) >= _SAMPLING_EPS or
+                             abs(params.frequency_penalty) >= _SAMPLING_EPS or
+                             params.repetition_penalty > 1.0)
+            do_tfss |= params.tfs < 1.0 - _SAMPLING_EPS
+            do_eta_cutoffs |= params.eta_cutoff > _SAMPLING_EPS
+            do_epsilon_cutoffs |= params.epsilon_cutoff > _SAMPLING_EPS
+            do_typical_ps |= params.typical_p < 1.0 - _SAMPLING_EPS
+            do_quadratic |= (params.smoothing_factor > _SAMPLING_EPS or
+                             params.smoothing_curve > 1.0)
+            do_xtc |= params.xtc_probability > _SAMPLING_EPS
+            do_temp_last |= params.temperature_last

             is_prompt = seq_group.is_prompt
-            if (is_prompt and sampling_params.prompt_logprobs is not None):
-                # For tokens in the prompt that we only need to get
-                # their logprobs
-                query_len = seq_group.query_len
-                assert query_len is not None
-                prefill_len = len(seq_group.prompt_logprob_indices)
-                temperatures += [temperature] * prefill_len
-                dynatemp_mins += [dynatemp_min] * prefill_len
-                dynatemp_maxs += [dynatemp_max] * prefill_len
-                dynatemp_exps += [dynatemp_exp] * prefill_len
-                temperature_lasts += [temperature_last] * prefill_len
-                top_ps += [top_p] * prefill_len
-                top_ks += [top_k] * prefill_len
-                top_as += [top_a] * prefill_len
-                min_ps += [min_p] * prefill_len
-                presence_penalties += [0] * prefill_len
-                frequency_penalties += [0] * prefill_len
-                repetition_penalties += [1] * prefill_len
-                tfss += [1] * prefill_len
-                eta_cutoffs += [0] * prefill_len
-                epsilon_cutoffs += [0] * prefill_len
-                typical_ps += [1] * prefill_len
-                smoothing_factors += [smoothing_factor] * prefill_len
-                smoothing_curves += [smoothing_curve] * prefill_len
-                xtc_thresholds += [xtc_threshold] * prefill_len
-                xtc_probabilities += [xtc_probability] * prefill_len
+            wants_prompt_logprobs = params.prompt_logprobs is not None
+
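+            # Count the rows this group contributes to the batch:
+            # prompt-logprob rows first, then one row per sampled sequence.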
+            n_prompt_seqs, n_sample_seqs = 0, 0
+            if seq_group.is_prompt and wants_prompt_logprobs:
+                assert seq_group.query_len is not None
+                n_prompt_seqs = len(seq_group.prompt_logprob_indices)

             if seq_group.do_sample:
-                sample_lens = len(seq_group.sample_indices)
-                assert sample_lens == len(seq_ids)
-                temperatures += [temperature] * len(seq_ids)
-                dynatemp_mins += [dynatemp_min] * len(seq_ids)
-                dynatemp_maxs += [dynatemp_max] * len(seq_ids)
-                dynatemp_exps += [dynatemp_exp] * len(seq_ids)
-                temperature_lasts += [temperature_last] * len(seq_ids)
-                top_ps += [top_p] * len(seq_ids)
-                top_ks += [top_k] * len(seq_ids)
-                top_as += [top_a] * len(seq_ids)
-                min_ps += [min_p] * len(seq_ids)
-                presence_penalties += [p] * len(seq_ids)
-                frequency_penalties += [f] * len(seq_ids)
-                repetition_penalties += [r] * len(seq_ids)
-                tfss += [tfs] * len(seq_ids)
-                eta_cutoffs += [eta_cutoff] * len(seq_ids)
-                epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
-                typical_ps += [typical_p] * len(seq_ids)
-                smoothing_factors += [smoothing_factor] * len(seq_ids)
-                smoothing_curves += [smoothing_curve] * len(seq_ids)
-                xtc_thresholds += [xtc_threshold] * len(seq_ids)
-                xtc_probabilities += [xtc_probability] * len(seq_ids)
+                assert len(seq_group.sample_indices) == len(seq_ids)
+                n_sample_seqs = len(seq_ids)
+
+            n_seqs = n_prompt_seqs + n_sample_seqs
+
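+            # Rows for a group are laid out prompt-logprob rows first, then
+            # sampled rows; the appends below rely on that ordering.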
+            # These parameters apply to ALL sequences
+            temperatures += [temperature] * n_seqs
+            dynatemp_mins += [params.dynatemp_min] * n_seqs
+            dynatemp_maxs += [params.dynatemp_max] * n_seqs
+            dynatemp_exps += [params.dynatemp_exponent] * n_seqs
+            temperature_lasts += [params.temperature_last] * n_seqs
+            top_ps += [params.top_p] * n_seqs
+            top_ks += [top_k] * n_seqs
+            top_as += [params.top_a] * n_seqs
+            min_ps += [params.min_p] * n_seqs
+            smoothing_factors += [params.smoothing_factor] * n_seqs
+            smoothing_curves += [params.smoothing_curve] * n_seqs
+            xtc_thresholds += [params.xtc_threshold] * n_seqs
+            xtc_probabilities += [params.xtc_probability] * n_seqs
+
+            # These parameters default to a neutral value for prompt-logprob
+            # rows, so penalties and filters leave those rows untouched.
+            presence_penalties += [0] * n_prompt_seqs
+            presence_penalties += [params.presence_penalty] * n_sample_seqs
+
+            frequency_penalties += [0] * n_prompt_seqs
+            frequency_penalties += [params.frequency_penalty] * n_sample_seqs
+
+            repetition_penalties += [1] * n_prompt_seqs
+            repetition_penalties += [params.repetition_penalty] * n_sample_seqs
+
+            tfss += [1] * n_prompt_seqs
+            tfss += [params.tfs] * n_sample_seqs
+
+            eta_cutoffs += [0] * n_prompt_seqs
+            eta_cutoffs += [params.eta_cutoff] * n_sample_seqs
+
+            epsilon_cutoffs += [0] * n_prompt_seqs
+            epsilon_cutoffs += [params.epsilon_cutoff] * n_sample_seqs
+
+            typical_ps += [1] * n_prompt_seqs
+            typical_ps += [params.typical_p] * n_sample_seqs

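+            # Triton sampler path: generate this group's per-sequence seed
+            # lists (filling sampling_seeds, one inner list per sequence).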
             if _USE_TRITON_SAMPLER:
                 if is_prompt:
-                    prompt_best_of.append(sampling_params.best_of)
+                    prompt_best_of.append(params.best_of)
                     query_len = seq_group.query_len
                     assert query_len is not None

-                seed = sampling_params.seed
-                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
+                seed = params.seed
+                is_greedy = params.sampling_type == SamplingType.GREEDY

                 for seq_id in seq_ids:
                     seq_data = seq_group.seq_data[seq_id]
@@ -592,7 +561,7 @@ class SamplingTensors:
             for seq_group in sampling_metadata.seq_groups:
                 seq_ids = seq_group.seq_ids
                 if (seq_group.is_prompt
-                        and sampling_params.prompt_logprobs is not None):
+                        and seq_group.sampling_params.prompt_logprobs
+                        is not None):
                     prefill_len = len(seq_group.prompt_logprob_indices)
                     prompt_tokens.extend(
                         array('l') for _ in range(prefill_len))
@@ -609,7 +578,7 @@ class SamplingTensors:
             temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
             tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, xtc_thresholds, xtc_probabilities,sampling_seeds,
+            smoothing_curves, xtc_thresholds, xtc_probabilities,
+            sampling_seeds,
             sample_indices, prompt_tokens, output_tokens, vocab_size,
             extra_seeds_to_generate, device, dtype)
         return (sampling_tensors, do_penalties, do_temperatures,
@@ -628,7 +597,8 @@ class SamplingTensors:
                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                   typical_ps: List[float], smoothing_factors: List[float],
                   smoothing_curves: List[float], xtc_thresholds: List[float],
-                  xtc_probabilities: List[float], sampling_seeds: List[int],
+                  xtc_probabilities: List[float],
+                  sampling_seeds: List[List[int]],
                   sample_indices: List[int], prompt_tokens: List[array],
                   output_tokens: List[array], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
@@ -827,7 +797,7 @@ class SamplingTensors:
     @staticmethod
     def _get_sequence_seeds(
-        seed: int,
+        seed: Optional[int],
         *extra_entropy: int,
         seeds_to_generate: int,
         is_greedy: bool,