Explorar o código

remove special-case values for prompt sequences

50h100a hai 5 meses
pai
achega
ed87b5dd32
Modificáronse 1 ficheiros con 14 adicións e 28 borrados
  1. 14 28
      aphrodite/modeling/sampling_metadata.py

+ 14 - 28
aphrodite/modeling/sampling_metadata.py

@@ -491,18 +491,15 @@ class SamplingTensors:
             is_prompt = seq_group.is_prompt
             wants_prompt_logprobs = params.prompt_logprobs is not None
 
-            n_prompt_seqs, n_sample_seqs = 0, 0
+            n_seqs = 0
             if seq_group.is_prompt and wants_prompt_logprobs:
                 assert seq_group.query_len is not None
-                n_prompt_seqs = len(seq_group.prompt_logprob_indices)
+                n_seqs += len(seq_group.prompt_logprob_indices)
 
             if seq_group.do_sample:
                 assert len(seq_group.sample_indices) == len(seq_ids)
-                n_sample_seqs = len(seq_ids)
+                n_seqs += len(seq_ids)
 
-            n_seqs = n_prompt_seqs + n_sample_seqs
-
-            # These parameters apply to ALL sequences
             temperatures += [temperature] * n_seqs
             dynatemp_mins += [params.dynatemp_min] * n_seqs
             dynatemp_maxs += [params.dynatemp_max] * n_seqs
@@ -512,28 +509,17 @@ class SamplingTensors:
             top_ks += [top_k] * n_seqs
             top_as += [params.top_a] * n_seqs
             min_ps += [params.min_p] * n_seqs
-
-            # These parameters have a default value for prompt tokens
-            presence_penalties += [0] * n_prompt_seqs
-            presence_penalties += [params.presence_penalty] * n_sample_seqs
-
-            frequency_penalties += [0] * n_prompt_seqs
-            frequency_penalties += [params.frequency_penalty] * n_sample_seqs
-
-            repetition_penalties += [1] * n_prompt_seqs
-            repetition_penalties += [params.repetition_penalty] * n_sample_seqs
-
-            tfss += [1] * n_prompt_seqs
-            tfss += [params.tfs] * n_sample_seqs
-
-            eta_cutoffs += [0] * n_prompt_seqs
-            eta_cutoffs += [params.eta_cutoff] * n_sample_seqs
-
-            epsilon_cutoffs += [0] * n_prompt_seqs
-            epsilon_cutoffs += [params.epsilon_cutoff] * n_sample_seqs
-
-            typical_ps += [1] * n_prompt_seqs
-            typical_ps += [params.typical_p] * n_sample_seqs
+            presence_penalties += [params.presence_penalty] * n_seqs
+            frequency_penalties += [params.frequency_penalty] * n_seqs
+            repetition_penalties += [params.repetition_penalty] * n_seqs
+            tfss += [params.tfs] * n_seqs
+            eta_cutoffs += [params.eta_cutoff] * n_seqs
+            epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
+            typical_ps += [params.typical_p] * n_seqs
+            smoothing_factors += [params.smoothing_factor] * n_seqs
+            smoothing_curves += [params.smoothing_curve] * n_seqs
+            xtc_thresholds += [params.xtc_threshold] * n_seqs
+            xtc_probabilities += [params.xtc_probability] * n_seqs
 
             if _USE_TRITON_SAMPLER:
                 if is_prompt: