|
@@ -11,6 +11,7 @@ from pydantic import Field
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
_SAMPLING_EPS = 1e-5
|
|
|
+_MAX_TEMP = 1e-2
|
|
|
|
|
|
APHRODITE_NO_DEPRECATION_WARNING = bool(
|
|
|
int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))
|
|
@@ -185,6 +186,12 @@ class SamplingParams:
|
|
|
self.presence_penalty = presence_penalty
|
|
|
self.frequency_penalty = frequency_penalty
|
|
|
self.repetition_penalty = repetition_penalty
|
|
|
+ if 0 < temperature < _MAX_TEMP:
|
|
|
+ logger.warning(
|
|
|
+ f"temperature {temperature} is less than {_MAX_TEMP}, "
|
|
|
+ f"which may cause numerical errors (NaN or Inf) in tensors. "
|
|
|
+ f"We have capped the temperature to {_MAX_TEMP}.")
|
|
|
+ temperature = min(temperature, _MAX_TEMP)
|
|
|
self.temperature = temperature
|
|
|
self.temperature_last = temperature_last
|
|
|
self.top_p = top_p
|