@@ -22,6 +22,11 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
 
+# There isn't a "safe" temperature range for fp16 logits.
+# This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
+# that this temperature well-uses the fp16 space after the logits are offset.
+_TEMPERATURE_MINIMUM = 2e-5
+
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -318,12 +323,16 @@ def _apply_temperatures(
     normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
     dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                 normalized_entropies.pow_(dynatemp_exps))
-
     temperatures[dynatemp_mask] = dyn_temp
-    temperatures[temperatures <= 0.0] = 1.0
-    # Use float32 to apply temp.
-    # Use in-place division to avoid creating a new tensor.
-    logits = logits.to(torch.float)
+
+    temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
+    temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
+
+    # To prevent saturation of top logits, we shift the range to [-inf, 1]
+    # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.
+    # Why mask? So we aren't potentially discarding data in milder temps.
+    low_temps = temperatures < 0.1
+    logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1
     logits.div_(temperatures.unsqueeze(dim=1))
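For context, a minimal standalone sketch of the fp16 arithmetic the new comments describe. The logit values below are invented for illustration; only `_TEMPERATURE_MINIMUM` and the shift-to-1 step mirror this diff:

```python
import torch

# Hypothetical logits; only _TEMPERATURE_MINIMUM and the shift come from the patch.
_TEMPERATURE_MINIMUM = 2e-5
fp16_logits = torch.tensor([20.0, 19.5, 3.0], dtype=torch.float16)

# Dividing raw fp16 logits by the minimum temperature overflows:
# 20 / 2e-5 = 1e6, far above the fp16 maximum of ~65504.
print(fp16_logits / _TEMPERATURE_MINIMUM)  # tensor([inf, inf, inf], dtype=torch.float16)

# Shifting the top logit to 1 first (as the patch does for temperatures < 0.1)
# keeps the winner representable: 1 / 2e-5 = 50000 < 65504. Logits far below
# the max may still saturate to -inf, which is harmless after softmax.
shifted = fp16_logits - (fp16_logits.max() - 1)  # range is now (-inf, 1]
print(shifted / _TEMPERATURE_MINIMUM)  # roughly tensor([50000., 25000., -inf])
```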