|
@@ -433,17 +433,10 @@ class SamplingTensors:
|
|
typical_ps += [1] * prefill_len
|
|
typical_ps += [1] * prefill_len
|
|
smoothing_factors += [smoothing_factor] * prefill_len
|
|
smoothing_factors += [smoothing_factor] * prefill_len
|
|
smoothing_curves += [smoothing_curve] * prefill_len
|
|
smoothing_curves += [smoothing_curve] * prefill_len
|
|
- if do_penalties:
|
|
|
|
- prompt_tokens.extend([] for _ in range(prefill_len))
|
|
|
|
- output_tokens.extend([] for _ in range(prefill_len))
|
|
|
|
|
|
|
|
if seq_group.do_sample:
|
|
if seq_group.do_sample:
|
|
sample_lens = len(seq_group.sample_indices)
|
|
sample_lens = len(seq_group.sample_indices)
|
|
assert sample_lens == len(seq_ids)
|
|
assert sample_lens == len(seq_ids)
|
|
- for seq_id in seq_ids:
|
|
|
|
- seq_data = seq_group.seq_data[seq_id]
|
|
|
|
- prompt_tokens.append(seq_data.prompt_token_ids)
|
|
|
|
- output_tokens.append(seq_data.output_token_ids)
|
|
|
|
temperatures += [temperature] * len(seq_ids)
|
|
temperatures += [temperature] * len(seq_ids)
|
|
top_ps += [top_p] * len(seq_ids)
|
|
top_ps += [top_p] * len(seq_ids)
|
|
top_ks += [top_k] * len(seq_ids)
|
|
top_ks += [top_k] * len(seq_ids)
|
|
@@ -477,6 +470,20 @@ class SamplingTensors:
|
|
sampling_seeds.append(seq_seeds)
|
|
sampling_seeds.append(seq_seeds)
|
|
sample_indices.extend(seq_group.sample_indices)
|
|
sample_indices.extend(seq_group.sample_indices)
|
|
|
|
|
|
|
|
+ if do_penalties:
|
|
|
|
+ for seq_group in sampling_metadata.seq_groups:
|
|
|
|
+ seq_ids = seq_group.seq_ids
|
|
|
|
+ if (seq_group.is_prompt
|
|
|
|
+ and sampling_params.prompt_logprobs is not None):
|
|
|
|
+ prefill_len = len(seq_group.prompt_logprob_indices)
|
|
|
|
+ prompt_tokens.extend([] for _ in range(prefill_len))
|
|
|
|
+ output_tokens.extend([] for _ in range(prefill_len))
|
|
|
|
+ if seq_group.do_sample:
|
|
|
|
+ for seq_id in seq_ids:
|
|
|
|
+ seq_data = seq_group.seq_data[seq_id]
|
|
|
|
+ prompt_tokens.append(seq_data.prompt_token_ids)
|
|
|
|
+ output_tokens.append(seq_data.output_token_ids)
|
|
|
|
+
|
|
sampling_tensors = SamplingTensors.from_lists(
|
|
sampling_tensors = SamplingTensors.from_lists(
|
|
temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
|
|
temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
|
|
frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
|
|
frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
|