Pārlūkot izejas kodu

fix: zero token output due to temperature bias (#243)

eliminate redundant loop code.
50h100a 1 gadu atpakaļ
vecāks
revīzija
f619c96c79
1 mainīts fails ar 6 papildinājumiem un 10 dzēšanām
  1. 6 10
      aphrodite/modeling/layers/sampler.py

+ 6 - 10
aphrodite/modeling/layers/sampler.py

@@ -400,6 +400,7 @@ def _apply_temperature(
                 normalized_entropies.pow_(dynatemp_exps))
 
     temperatures[dynatemp_mask] = dyn_temp
+    temperatures[temperatures == 0.0] = 1.0
     logits.div_(temperatures.unsqueeze_(dim=1))
     return logits
 
@@ -556,12 +557,10 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
 
-    # Counterintiutively, having two loops here is actually faster.
+    # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
-    for sampling_type in SamplingType:
-        sample_indices = categorized_sample_indices[sampling_type]
-        num_tokens = len(sample_indices)
-        if num_tokens == 0:
+    for sampling_type, sample_indices in categorized_sample_indices.items():
+        if len(sample_indices) == 0:
             continue
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
@@ -585,11 +584,8 @@ def _sample(
 
     # GPU<->CPU sync happens in the loop below.
 
-    for sampling_type in SamplingType:
-        if sampling_type not in sample_metadata:
-            continue
-        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
-            sampling_type]
+    for sampling_type, metadata in sample_metadata.items():
+        seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
         elif sampling_type == SamplingType.RANDOM: