Browse Source

guard against nan temperature from dynatemp (or anywhere else).

50h100a 3 months ago
parent
commit
273c61d406
1 changed file with 7 additions and 4 deletions
  1. 7 4
      aphrodite/modeling/layers/sampler.py

+ 7 - 4
aphrodite/modeling/layers/sampler.py

@@ -22,6 +22,11 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
 
+# There isn't a "safe" temperature range for fp16 logits.
+# This value was chosen because 1/2e-5 (50000) is just under the fp16 max of
+# 65504, so this temperature uses the fp16 range well once logits are offset.
+_TEMPERATURE_MINIMUM = 2e-5
+
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -320,10 +325,8 @@ def _apply_temperatures(
                 normalized_entropies.pow_(dynatemp_exps))
     temperatures[dynatemp_mask] = dyn_temp
   
-    # There isn't a "safe" temperature range for fp16 logits.
-    # This value was chosen because 1/2e-5 is just under the 65k fp16 max,
-    # meaning that this temperature well-uses the fp16 space after the offset.
-    temperatures[temperatures <= 2e-5] = 2e-5
+    temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
+    temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
   
     # To prevent saturation of top logits, we shift the range to [-inf, 1]
     # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.