Просмотр исходного кода

fix: greedy sampling not being greedy in concurrent situations where penalties are used

AlpinDale 7 месяцев назад
Родитель
Commit
d1f91d0f70
1 изменённый файл: 14 добавлено, 7 удалено
  1. 14 7
      aphrodite/modeling/sampling_metadata.py

+ 14 - 7
aphrodite/modeling/sampling_metadata.py

@@ -433,17 +433,10 @@ class SamplingTensors:
                 typical_ps += [1] * prefill_len
                 typical_ps += [1] * prefill_len
                 smoothing_factors += [smoothing_factor] * prefill_len
                 smoothing_factors += [smoothing_factor] * prefill_len
                 smoothing_curves += [smoothing_curve] * prefill_len
                 smoothing_curves += [smoothing_curve] * prefill_len
-                if do_penalties:
-                    prompt_tokens.extend([] for _ in range(prefill_len))
-                    output_tokens.extend([] for _ in range(prefill_len))
 
 
             if seq_group.do_sample:
             if seq_group.do_sample:
                 sample_lens = len(seq_group.sample_indices)
                 sample_lens = len(seq_group.sample_indices)
                 assert sample_lens == len(seq_ids)
                 assert sample_lens == len(seq_ids)
-                for seq_id in seq_ids:
-                    seq_data = seq_group.seq_data[seq_id]
-                    prompt_tokens.append(seq_data.prompt_token_ids)
-                    output_tokens.append(seq_data.output_token_ids)
                 temperatures += [temperature] * len(seq_ids)
                 temperatures += [temperature] * len(seq_ids)
                 top_ps += [top_p] * len(seq_ids)
                 top_ps += [top_p] * len(seq_ids)
                 top_ks += [top_k] * len(seq_ids)
                 top_ks += [top_k] * len(seq_ids)
@@ -477,6 +470,20 @@ class SamplingTensors:
                 sampling_seeds.append(seq_seeds)
                 sampling_seeds.append(seq_seeds)
             sample_indices.extend(seq_group.sample_indices)
             sample_indices.extend(seq_group.sample_indices)
 
 
+        if do_penalties:
+            for seq_group in sampling_metadata.seq_groups:
+                seq_ids = seq_group.seq_ids
+                if (seq_group.is_prompt
+                        and sampling_params.prompt_logprobs is not None):
+                    prefill_len = len(seq_group.prompt_logprob_indices)
+                    prompt_tokens.extend([] for _ in range(prefill_len))
+                    output_tokens.extend([] for _ in range(prefill_len))
+                if seq_group.do_sample:
+                    for seq_id in seq_ids:
+                        seq_data = seq_group.seq_data[seq_id]
+                        prompt_tokens.append(seq_data.prompt_token_ids)
+                        output_tokens.append(seq_data.output_token_ids)
+
         sampling_tensors = SamplingTensors.from_lists(
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
             temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
             frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
             frequency_penalties, repetition_penalties, tfss, eta_cutoffs,