1
0
View source

temporarily disable dynatemp

AlpinDale 8 months ago
parent
commit
772b4a4504

+ 26 - 25
aphrodite/modeling/layers/sampler.py

@@ -93,9 +93,10 @@ class Sampler(nn.Module):
 
         if do_temperatures:
             logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                        sampling_tensors.dynatemp_mins,
-                                        sampling_tensors.dynatemp_maxs,
-                                        sampling_tensors.dynatemp_exps)
+                                        # sampling_tensors.dynatemp_mins,
+                                        # sampling_tensors.dynatemp_maxs,
+                                        # sampling_tensors.dynatemp_exps
+                                        )
 
         banned_tokens = _get_custom_token_bans(sampling_metadata)
         # assert len(banned_tokens) == logits.shape[0]
@@ -396,29 +397,29 @@ def _apply_typical_sampling(
 def _apply_temperature(
     logits: torch.Tensor,
     temperatures: torch.Tensor,
-    dynatemp_mins: torch.Tensor,
-    dynatemp_maxs: torch.Tensor,
-    dynatemp_exps: torch.Tensor,
+    # dynatemp_mins: torch.Tensor,
+    # dynatemp_maxs: torch.Tensor,
+    # dynatemp_exps: torch.Tensor,
 ) -> torch.Tensor:
-    dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
-    dynatemp_mins = dynatemp_mins[dynatemp_mask]
-    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
-    dynatemp_exps = dynatemp_exps[dynatemp_mask]
-    dynatemp_mins = dynatemp_mins.clamp_(min=0)
-
-    dynatemp_logits = logits[dynatemp_mask]
-    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
-    dynatemp_probs = dynatemp_shifted_logits.exp()
-    dynatemp_entropies = -(dynatemp_probs *
-                           dynatemp_shifted_logits).nansum(dim=-1)
-    dynatemp_max_entropies = torch.log_(
-        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
-    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
-    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
-                normalized_entropies.pow_(dynatemp_exps))
-
-    temperatures[dynatemp_mask] = dyn_temp
-    temperatures[temperatures == 0.0] = 1.0
+    # dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
+    # dynatemp_mins = dynatemp_mins[dynatemp_mask]
+    # dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
+    # dynatemp_exps = dynatemp_exps[dynatemp_mask]
+    # dynatemp_mins = dynatemp_mins.clamp_(min=0)
+
+    # dynatemp_logits = logits[dynatemp_mask]
+    # dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
+    # dynatemp_probs = dynatemp_shifted_logits.exp()
+    # dynatemp_entropies = -(dynatemp_probs *
+    #                        dynatemp_shifted_logits).nansum(dim=-1)
+    # dynatemp_max_entropies = torch.log_(
+    #     (dynatemp_logits > float("-inf")).sum(dim=-1).float())
+    # normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
+    # dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
+    #             normalized_entropies.pow_(dynatemp_exps))
+
+    # temperatures[dynatemp_mask] = dyn_temp
+    # temperatures[temperatures == 0.0] = 1.0
     logits.div_(temperatures.unsqueeze_(dim=1))
     return logits