Răsfoiți Sursa

fix: empty sampler output when temperature is too low (#709)

AlpinDale 6 luni în urmă
părinte
comite
198029295c

+ 7 - 0
aphrodite/common/sampling_params.py

@@ -11,6 +11,7 @@ from pydantic import Field
 from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
+_MAX_TEMP = 1e-2
 
 APHRODITE_NO_DEPRECATION_WARNING = bool(
     int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))
@@ -185,6 +186,12 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.repetition_penalty = repetition_penalty
+        if 0 < temperature < _MAX_TEMP:
+            logger.warning(
+                f"temperature {temperature} is less than {_MAX_TEMP}, "
+                f"which may cause numerical errors (NaN or Inf) in tensors. "
+                f"We have capped the temperature to {_MAX_TEMP}.")
+            temperature = min(temperature, _MAX_TEMP)
         self.temperature = temperature
         self.temperature_last = temperature_last
         self.top_p = top_p

+ 4 - 0
aphrodite/modeling/layers/sampler.py

@@ -135,7 +135,9 @@ class Sampler(nn.Module):
 
         # Apply temperature scaling if not doing temp_last.
         if not do_temp_last:
+            # Use float32 to apply temp.
             # Use in-place division to avoid creating a new tensor.
+            logits = logits.to(torch.float)
             logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k:
@@ -168,7 +170,9 @@ class Sampler(nn.Module):
                 sampling_tensors.smoothing_curves)
 
         if do_temp_last:
+            # Use float32 to apply temp.
             # Use in-place division to avoid creating a new tensor.
+            logits = logits.to(torch.float)
             logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         # banned_tokens = _get_custom_token_bans(sampling_metadata)