|
@@ -491,18 +491,15 @@ class SamplingTensors:
|
|
|
is_prompt = seq_group.is_prompt
|
|
|
wants_prompt_logprobs = params.prompt_logprobs is not None
|
|
|
|
|
|
- n_prompt_seqs, n_sample_seqs = 0, 0
|
|
|
+ n_seqs = 0
|
|
|
if seq_group.is_prompt and wants_prompt_logprobs:
|
|
|
assert seq_group.query_len is not None
|
|
|
- n_prompt_seqs = len(seq_group.prompt_logprob_indices)
|
|
|
+ n_seqs += len(seq_group.prompt_logprob_indices)
|
|
|
|
|
|
if seq_group.do_sample:
|
|
|
assert len(seq_group.sample_indices) == len(seq_ids)
|
|
|
- n_sample_seqs = len(seq_ids)
|
|
|
+ n_seqs += len(seq_ids)
|
|
|
|
|
|
- n_seqs = n_prompt_seqs + n_sample_seqs
|
|
|
-
|
|
|
- # These parameters apply to ALL sequences
|
|
|
temperatures += [temperature] * n_seqs
|
|
|
dynatemp_mins += [params.dynatemp_min] * n_seqs
|
|
|
dynatemp_maxs += [params.dynatemp_max] * n_seqs
|
|
@@ -512,28 +509,17 @@ class SamplingTensors:
|
|
|
top_ks += [top_k] * n_seqs
|
|
|
top_as += [params.top_a] * n_seqs
|
|
|
min_ps += [params.min_p] * n_seqs
|
|
|
-
|
|
|
- # These parameters have a default value for prompt tokens
|
|
|
- presence_penalties += [0] * n_prompt_seqs
|
|
|
- presence_penalties += [params.presence_penalty] * n_sample_seqs
|
|
|
-
|
|
|
- frequency_penalties += [0] * n_prompt_seqs
|
|
|
- frequency_penalties += [params.frequency_penalty] * n_sample_seqs
|
|
|
-
|
|
|
- repetition_penalties += [1] * n_prompt_seqs
|
|
|
- repetition_penalties += [params.repetition_penalty] * n_sample_seqs
|
|
|
-
|
|
|
- tfss += [1] * n_prompt_seqs
|
|
|
- tfss += [params.tfs] * n_sample_seqs
|
|
|
-
|
|
|
- eta_cutoffs += [0] * n_prompt_seqs
|
|
|
- eta_cutoffs += [params.eta_cutoff] * n_sample_seqs
|
|
|
-
|
|
|
- epsilon_cutoffs += [0] * n_prompt_seqs
|
|
|
- epsilon_cutoffs += [params.epsilon_cutoff] * n_sample_seqs
|
|
|
-
|
|
|
- typical_ps += [1] * n_prompt_seqs
|
|
|
- typical_ps += [params.typical_p] * n_sample_seqs
|
|
|
+ presence_penalties += [params.presence_penalty] * n_seqs
|
|
|
+ frequency_penalties += [params.frequency_penalty] * n_seqs
|
|
|
+ repetition_penalties += [params.repetition_penalty] * n_seqs
|
|
|
+ tfss += [params.tfs] * n_seqs
|
|
|
+ eta_cutoffs += [params.eta_cutoff] * n_seqs
|
|
|
+ epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
|
|
|
+ typical_ps += [params.typical_p] * n_seqs
|
|
|
+ smoothing_factors += [params.smoothing_factor] * n_seqs
|
|
|
+ smoothing_curves += [params.smoothing_curve] * n_seqs
|
|
|
+ xtc_thresholds += [params.xtc_threshold] * n_seqs
|
|
|
+ xtc_probabilities += [params.xtc_probability] * n_seqs
|
|
|
|
|
|
if _USE_TRITON_SAMPLER:
|
|
|
if is_prompt:
|