@@ -93,9 +93,10 @@ class Sampler(nn.Module):
         if do_temperatures:
             logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                        sampling_tensors.dynatemp_mins,
-                                        sampling_tensors.dynatemp_maxs,
-                                        sampling_tensors.dynatemp_exps)
+                                        # sampling_tensors.dynatemp_mins,
+                                        # sampling_tensors.dynatemp_maxs,
+                                        # sampling_tensors.dynatemp_exps
+                                        )
 
         banned_tokens = _get_custom_token_bans(sampling_metadata)
         # assert len(banned_tokens) == logits.shape[0]
 
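Note: with the dynatemp arguments commented out above, _apply_temperature reduces to plain static temperature scaling. A minimal sketch of that reduced behavior (a hypothetical standalone equivalent for illustration, not the repository's code, which divides in place via logits.div_()):

import torch

def apply_static_temperature(logits: torch.Tensor,
                             temperatures: torch.Tensor) -> torch.Tensor:
    # one temperature per sequence, broadcast across the vocab dimension
    return logits / temperatures.unsqueeze(1)
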
@@ -396,29 +397,29 @@ def _apply_typical_sampling(
 def _apply_temperature(
     logits: torch.Tensor,
     temperatures: torch.Tensor,
-    dynatemp_mins: torch.Tensor,
-    dynatemp_maxs: torch.Tensor,
-    dynatemp_exps: torch.Tensor,
+    # dynatemp_mins: torch.Tensor,
+    # dynatemp_maxs: torch.Tensor,
+    # dynatemp_exps: torch.Tensor,
 ) -> torch.Tensor:
-    dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
-    dynatemp_mins = dynatemp_mins[dynatemp_mask]
-    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
-    dynatemp_exps = dynatemp_exps[dynatemp_mask]
-    dynatemp_mins = dynatemp_mins.clamp_(min=0)
-
-    dynatemp_logits = logits[dynatemp_mask]
-    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
-    dynatemp_probs = dynatemp_shifted_logits.exp()
-    dynatemp_entropies = -(dynatemp_probs *
-                           dynatemp_shifted_logits).nansum(dim=-1)
-    dynatemp_max_entropies = torch.log_(
-        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
-    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
-    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
-                normalized_entropies.pow_(dynatemp_exps))
-
-    temperatures[dynatemp_mask] = dyn_temp
-    temperatures[temperatures == 0.0] = 1.0
+    # dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
+    # dynatemp_mins = dynatemp_mins[dynatemp_mask]
+    # dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
+    # dynatemp_exps = dynatemp_exps[dynatemp_mask]
+    # dynatemp_mins = dynatemp_mins.clamp_(min=0)
+
+    # dynatemp_logits = logits[dynatemp_mask]
+    # dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
+    # dynatemp_probs = dynatemp_shifted_logits.exp()
+    # dynatemp_entropies = -(dynatemp_probs *
+    #                        dynatemp_shifted_logits).nansum(dim=-1)
+    # dynatemp_max_entropies = torch.log_(
+    #     (dynatemp_logits > float("-inf")).sum(dim=-1).float())
+    # normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
+    # dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
+    #             normalized_entropies.pow_(dynatemp_exps))
+
+    # temperatures[dynatemp_mask] = dyn_temp
+    # temperatures[temperatures == 0.0] = 1.0
     logits.div_(temperatures.unsqueeze_(dim=1))
     return logits
 
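For reference, the block commented out above implemented entropy-based dynamic temperature: per row, the temperature is interpolated as T = T_min + (T_max - T_min) * (H / H_max) ** exp, where H is the Shannon entropy of the row's softmax distribution and H_max is the log of the number of unmasked tokens. A self-contained sketch of that logic, assuming an out-of-place rewrite (the function name and the clone are illustrative; the original mutated temperatures and logits in place, and also folded in the temperature-zero guard that this patch disables):

import torch

def dynamic_temperature(logits: torch.Tensor,
                        temperatures: torch.Tensor,
                        dynatemp_mins: torch.Tensor,
                        dynatemp_maxs: torch.Tensor,
                        dynatemp_exps: torch.Tensor) -> torch.Tensor:
    # rows where dynamic temperature is requested
    mask = (dynatemp_mins > 0) | (dynatemp_maxs > 0)
    mins = dynatemp_mins[mask].clamp(min=0)
    maxs = dynatemp_maxs[mask]
    exps = dynatemp_exps[mask]

    row_logits = logits[mask]
    log_probs = torch.log_softmax(row_logits, dim=-1)
    probs = log_probs.exp()
    # Shannon entropy per row; nansum drops the 0 * -inf = nan terms
    # produced by tokens masked to -inf
    entropies = -(probs * log_probs).nansum(dim=-1)
    # maximum possible entropy: log of the count of unmasked tokens
    max_entropies = (row_logits > float("-inf")).sum(dim=-1).float().log()
    normalized = entropies / max_entropies
    # interpolate between min and max temperature by normalized entropy
    dyn_temp = mins + (maxs - mins) * normalized.pow(exps)

    temperatures = temperatures.clone()
    temperatures[mask] = dyn_temp
    # guard rows with temperature 0 against division by zero,
    # as the commented-out code did
    temperatures[temperatures == 0.0] = 1.0
    return logits / temperatures.unsqueeze(1)

Example usage: a batch of two rows, the first with dynatemp in [0.5, 1.5], the second falling back to its static temperature:

logits = torch.randn(2, 32000)
out = dynamic_temperature(
    logits,
    temperatures=torch.tensor([1.0, 0.7]),
    dynatemp_mins=torch.tensor([0.5, 0.0]),
    dynatemp_maxs=torch.tensor([1.5, 0.0]),
    dynatemp_exps=torch.tensor([1.0, 1.0]),
)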