Prechádzať zdrojové kódy

Merge pull request #766 from 50h100a/simplifymeta

Simplify construction of sampling_metadata
50h100a pred 5 mesiacmi
rodič
commit
5795565e4b
1 zmenený súbor: 61 pridaní a 105 odobraní
  1. 61 105
      aphrodite/modeling/sampling_metadata.py

+ 61 - 105
aphrodite/modeling/sampling_metadata.py

@@ -431,7 +431,7 @@ class SamplingTensors:
         smoothing_curves: List[float] = []
         xtc_thresholds: List[float] = []
         xtc_probabilities: List[float] = []
-        sampling_seeds: List[int] = []
+        sampling_seeds: List[List[int]] = []
         sample_indices: List[int] = []
         do_penalties = False
         do_temperatures = False
@@ -456,124 +456,79 @@ class SamplingTensors:
         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
             seq_ids = seq_group.seq_ids
-            sampling_params = seq_group.sampling_params
-            temperature = sampling_params.temperature
-            dynatemp_min = sampling_params.dynatemp_min
-            dynatemp_max = sampling_params.dynatemp_max
-            dynatemp_exp = sampling_params.dynatemp_exponent
-            temperature_last = sampling_params.temperature_last
-            p = sampling_params.presence_penalty
-            f = sampling_params.frequency_penalty
-            r = sampling_params.repetition_penalty
-            top_p = sampling_params.top_p
-            top_a = sampling_params.top_a
-            min_p = sampling_params.min_p
-            tfs = sampling_params.tfs
-            eta_cutoff = sampling_params.eta_cutoff
-            epsilon_cutoff = sampling_params.epsilon_cutoff
-            typical_p = sampling_params.typical_p
-            smoothing_factor = sampling_params.smoothing_factor
-            smoothing_curve = sampling_params.smoothing_curve
-            xtc_threshold = sampling_params.xtc_threshold
-            xtc_probability = sampling_params.xtc_probability
+            params = seq_group.sampling_params
 
             # k should not be greater than the vocab size.
-            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = min(params.top_k, vocab_size)
             top_k = vocab_size if top_k == -1 else top_k
+
+            temperature = params.temperature
             if temperature < _SAMPLING_EPS:
                 # NOTE: Zero temperature means deterministic sampling
                 # (i.e., greedy sampling or beam search).
                 # Set the temperature to 1 to avoid division by zero.
                 temperature = 1.0
-            if not do_temperatures and temperature != 1.0:
-                do_temperatures = True
-            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
-                                       or top_k != vocab_size):
-                do_top_p_top_k = True
-            if do_top_as is False and top_a > 0.0:
-                do_top_as = True
-            if not do_min_p and min_p > _SAMPLING_EPS:
-                do_min_p = True
-            if not do_penalties and (abs(p) >= _SAMPLING_EPS
-                                     or abs(f) >= _SAMPLING_EPS
-                                     or abs(r - 1.0) >= _SAMPLING_EPS):
-                do_penalties = True
-            if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
-                do_tfss = True
-            if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
-                do_eta_cutoffs = True
-            if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
-                do_epsilon_cutoffs = True
-            if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
-                do_typical_ps = True
-            if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
-                                          or smoothing_curve > 1.0):
-                do_quadratic = True
-            if do_xtc is False and xtc_probability > _SAMPLING_EPS:
-                do_xtc = True
-            if do_temp_last is False and temperature_last:
-                do_temp_last = True
+
+            do_temperatures |= (temperature != 1.0 or
+                                params.dynatemp_min > _SAMPLING_EPS or
+                                params.dynatemp_max > _SAMPLING_EPS)
+            do_top_p_top_k |= (params.top_p < 1.0 - _SAMPLING_EPS or
+                               top_k != vocab_size)
+            do_top_as |= params.top_a > 0.0
+            do_min_p |= params.min_p > _SAMPLING_EPS
+            do_penalties |= (abs(params.presence_penalty) >= _SAMPLING_EPS or
+                             abs(params.frequency_penalty) >= _SAMPLING_EPS or
+                             params.repetition_penalty > 1.0)
+            do_tfss |= params.tfs < 1.0 - _SAMPLING_EPS
+            do_eta_cutoffs |= params.eta_cutoff > _SAMPLING_EPS
+            do_epsilon_cutoffs |= params.epsilon_cutoff > _SAMPLING_EPS
+            do_typical_ps |= params.typical_p < 1.0 - _SAMPLING_EPS
+            do_quadratic |= (params.smoothing_factor > _SAMPLING_EPS or
+                             params.smoothing_curve > 1.0)
+            do_xtc |= params.xtc_probability > _SAMPLING_EPS
+            do_temp_last |= params.temperature_last
 
             is_prompt = seq_group.is_prompt
-            if (is_prompt and sampling_params.prompt_logprobs is not None):
-                # For tokens in the prompt that we only need to get
-                # their logprobs
-                query_len = seq_group.query_len
-                assert query_len is not None
-                prefill_len = len(seq_group.prompt_logprob_indices)
-                temperatures += [temperature] * prefill_len
-                dynatemp_mins += [dynatemp_min] * prefill_len
-                dynatemp_maxs += [dynatemp_max] * prefill_len
-                dynatemp_exps += [dynatemp_exp] * prefill_len
-                temperature_lasts += [temperature_last] * prefill_len
-                top_ps += [top_p] * prefill_len
-                top_ks += [top_k] * prefill_len
-                top_as += [top_a] * prefill_len
-                min_ps += [min_p] * prefill_len
-                presence_penalties += [0] * prefill_len
-                frequency_penalties += [0] * prefill_len
-                repetition_penalties += [1] * prefill_len
-                tfss += [1] * prefill_len
-                eta_cutoffs += [0] * prefill_len
-                epsilon_cutoffs += [0] * prefill_len
-                typical_ps += [1] * prefill_len
-                smoothing_factors += [smoothing_factor] * prefill_len
-                smoothing_curves += [smoothing_curve] * prefill_len
-                xtc_thresholds += [xtc_threshold] * prefill_len
-                xtc_probabilities += [xtc_probability] * prefill_len
+            wants_prompt_logprobs = params.prompt_logprobs is not None
+
+            n_seqs = 0
+            if seq_group.is_prompt and wants_prompt_logprobs:
+                assert seq_group.query_len is not None
+                n_seqs += len(seq_group.prompt_logprob_indices)
 
             if seq_group.do_sample:
-                sample_lens = len(seq_group.sample_indices)
-                assert sample_lens == len(seq_ids)
-                temperatures += [temperature] * len(seq_ids)
-                dynatemp_mins += [dynatemp_min] * len(seq_ids)
-                dynatemp_maxs += [dynatemp_max] * len(seq_ids)
-                dynatemp_exps += [dynatemp_exp] * len(seq_ids)
-                temperature_lasts += [temperature_last] * len(seq_ids)
-                top_ps += [top_p] * len(seq_ids)
-                top_ks += [top_k] * len(seq_ids)
-                top_as += [top_a] * len(seq_ids)
-                min_ps += [min_p] * len(seq_ids)
-                presence_penalties += [p] * len(seq_ids)
-                frequency_penalties += [f] * len(seq_ids)
-                repetition_penalties += [r] * len(seq_ids)
-                tfss += [tfs] * len(seq_ids)
-                eta_cutoffs += [eta_cutoff] * len(seq_ids)
-                epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
-                typical_ps += [typical_p] * len(seq_ids)
-                smoothing_factors += [smoothing_factor] * len(seq_ids)
-                smoothing_curves += [smoothing_curve] * len(seq_ids)
-                xtc_thresholds += [xtc_threshold] * len(seq_ids)
-                xtc_probabilities += [xtc_probability] * len(seq_ids)
+                assert len(seq_group.sample_indices) == len(seq_ids)
+                n_seqs += len(seq_ids)
+
+            temperatures += [temperature] * n_seqs
+            dynatemp_mins += [params.dynatemp_min] * n_seqs
+            dynatemp_maxs += [params.dynatemp_max] * n_seqs
+            dynatemp_exps += [params.dynatemp_exponent] * n_seqs
+            temperature_lasts += [params.temperature_last] * n_seqs
+            top_ps += [params.top_p] * n_seqs
+            top_ks += [top_k] * n_seqs
+            top_as += [params.top_a] * n_seqs
+            min_ps += [params.min_p] * n_seqs
+            presence_penalties += [params.presence_penalty] * n_seqs
+            frequency_penalties += [params.frequency_penalty] * n_seqs
+            repetition_penalties += [params.repetition_penalty] * n_seqs
+            tfss += [params.tfs] * n_seqs
+            eta_cutoffs += [params.eta_cutoff] * n_seqs
+            epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
+            typical_ps += [params.typical_p] * n_seqs
+            smoothing_factors += [params.smoothing_factor] * n_seqs
+            smoothing_curves += [params.smoothing_curve] * n_seqs
+            xtc_thresholds += [params.xtc_threshold] * n_seqs
+            xtc_probabilities += [params.xtc_probability] * n_seqs
 
             if _USE_TRITON_SAMPLER:
                 if is_prompt:
-                    prompt_best_of.append(sampling_params.best_of)
+                    prompt_best_of.append(params.best_of)
                     query_len = seq_group.query_len
                     assert query_len is not None
 
-                seed = sampling_params.seed
-                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
+                seed = params.seed
+                is_greedy = params.sampling_type == SamplingType.GREEDY
 
                 for seq_id in seq_ids:
                     seq_data = seq_group.seq_data[seq_id]
@@ -592,7 +547,7 @@ class SamplingTensors:
             for seq_group in sampling_metadata.seq_groups:
                 seq_ids = seq_group.seq_ids
                 if (seq_group.is_prompt
-                        and sampling_params.prompt_logprobs is not None):
+                        and params.prompt_logprobs is not None):
                     prefill_len = len(seq_group.prompt_logprob_indices)
                     prompt_tokens.extend(
                         array('l') for _ in range(prefill_len))
@@ -609,7 +564,7 @@ class SamplingTensors:
             temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
             tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, xtc_thresholds, xtc_probabilities,sampling_seeds,
+            smoothing_curves, xtc_thresholds, xtc_probabilities, sampling_seeds,
             sample_indices, prompt_tokens, output_tokens, vocab_size,
             extra_seeds_to_generate, device, dtype)
         return (sampling_tensors, do_penalties, do_temperatures,
@@ -628,7 +583,8 @@ class SamplingTensors:
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], xtc_thresholds: List[float],
-                   xtc_probabilities: List[float], sampling_seeds: List[int],
+                   xtc_probabilities: List[float],
+                   sampling_seeds: List[List[int]],
                    sample_indices: List[int], prompt_tokens: List[array],
                    output_tokens: List[array], vocab_size: int,
                    extra_seeds_to_generate: int, device: torch.device,
@@ -827,7 +783,7 @@ class SamplingTensors:
 
     @staticmethod
     def _get_sequence_seeds(
-        seed: int,
+        seed: int|None,
         *extra_entropy: int,
         seeds_to_generate: int,
         is_greedy: bool,