Ver Fonte

Merge pull request #814 from PygmalionAI/50h100a-temp-fix

fix: temperature issues
50h100a há 3 meses atrás
pai
commit
a5346b2ea5
1 ficheiros alterados com 14 adições e 5 exclusões
  1. 14 5
      aphrodite/modeling/layers/sampler.py

+ 14 - 5
aphrodite/modeling/layers/sampler.py

@@ -22,6 +22,11 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
 
+# There isn't a "safe" temperature range for fp16 logits.
+# This value was chosen because 1/2e-5 = 50000 is below the fp16 max (65504),
+# so dividing by it makes good use of the fp16 range once logits are offset.
+_TEMPERATURE_MINIMUM = 2e-5
+
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -318,12 +323,16 @@ def _apply_temperatures(
     normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
     dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                 normalized_entropies.pow_(dynatemp_exps))
-
     temperatures[dynatemp_mask] = dyn_temp
-    temperatures[temperatures <= 0.0] = 1.0
-    # Use float32 to apply temp.
-    # Use in-place division to avoid creating a new tensor.
-    logits = logits.to(torch.float)
+  
+    temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
+    temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
+  
+    # To prevent saturation of top logits, shift each row's range to (-inf, 1].
+    # Why align to 1 instead of 0? Because [0, 1] holds roughly 25% of all floats.
+    # Why mask? So we don't needlessly risk discarding data at milder temperatures.
+    low_temps = temperatures < 0.1
+    logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1
     logits.div_(temperatures.unsqueeze(dim=1))