瀏覽代碼

Misc fixes in eta, eps, and typical

Stefan Gligorijevic 1 年之前
父節點
當前提交
99f76323ad
共有 1 個文件被更改,包括 11 次插入10 次删除
  1. 11 10
      aphrodite/modeling/layers/sampler.py

+ 11 - 10
aphrodite/modeling/layers/sampler.py

@@ -384,8 +384,8 @@ def _apply_eta_cutoff(
     eta_mask = probs < eps
 
     if(torch.all(eta_mask)): # guard against nulling out all the logits
-        _, max_idx = torch.max(probs, dim=-1)
-        eta_mask[max_idx] = False
+        topk_prob, _ = torch.max(probs, dim=-1)
+        eta_mask = probs < topk_prob
 
     logits[eta_mask] = -float("inf")
     return logits
@@ -395,14 +395,14 @@ def _apply_epsilon_cutoff(
     logits: torch.Tensor,
     epsilon_cutoffs: List[float],
 ) -> torch.Tensor:
-    probs = torch.softmax(logits, dim=-1)
     eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device)
+    probs = logits.softmax(dim=-1)
 
-    eps_mask = probs < eps
+    eps_mask = probs < (eps * 1e-4)
 
     if(torch.all(eps_mask)): # guard against nulling out all the logits
-        _, max_idx = torch.max(probs, dim=-1)
-        eps_mask[max_idx] = False
+        topk_prob, _ = torch.max(probs, dim=-1)
+        eps_mask = probs < topk_prob
 
     logits[eps_mask] = -float("inf")
     return logits
@@ -412,15 +412,16 @@ def _apply_typical_sampling(
     logits: torch.Tensor,
     typical_ps: List[float],
 ) -> torch.Tensor:
+    typ_p = torch.tensor(typical_ps, dtype=logits.dtype, device=logits.device)
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()
 
     neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
 
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
-    _, indices = torch.sort(surprisal_deviation)
-    reordered_probs = probs.gather(-1, sorted_indices)
-    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_ps
+    _, indices = torch.sort(surprisal_deviations)
+    reordered_probs = probs.gather(-1, indices)
+    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p
     
     min_tokens_to_keep = 1
     # Keep at least min_tokens_to_keep
@@ -667,4 +668,4 @@ def _sample(
         assert sample_idx == num_tokens
         category_start_idx += num_tokens
 
-    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
+    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]