Переглянути джерело

fix eta,eps and typical for parallel requests

Stefan Gligorijevic 1 рік тому
батько
коміт
8a6c9f5cbd
1 змінених файлів з 5 додано та 5 видалено
  1. 5 5
      aphrodite/modeling/layers/sampler.py

+ 5 - 5
aphrodite/modeling/layers/sampler.py

@@ -374,12 +374,12 @@ def _apply_eta_cutoff(
     logits: torch.Tensor,
     logits: torch.Tensor,
     eta_cutoffs: List[float],
     eta_cutoffs: List[float],
 ) -> torch.Tensor:
 ) -> torch.Tensor:
-    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device)
+    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device) * 1e-4
     shifted_logits = torch.log_softmax(logits, dim=-1)
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()
     probs = shifted_logits.exp()
 
 
-    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
-    eps = torch.min(eta * 1e-4, torch.sqrt(eta*1e-4)*torch.exp(neg_entropy))
+    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
+    eps = torch.min(eta, torch.sqrt(eta)*torch.exp(neg_entropy)).unsqueeze(dim=1)
 
 
     eta_mask = probs < eps
     eta_mask = probs < eps
 
 
@@ -395,7 +395,7 @@ def _apply_epsilon_cutoff(
     logits: torch.Tensor,
     logits: torch.Tensor,
     epsilon_cutoffs: List[float],
     epsilon_cutoffs: List[float],
 ) -> torch.Tensor:
 ) -> torch.Tensor:
-    eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device)
+    eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device).unsqueeze(dim=1)
     probs = logits.softmax(dim=-1)
     probs = logits.softmax(dim=-1)
 
 
     eps_mask = probs < (eps * 1e-4)
     eps_mask = probs < (eps * 1e-4)
@@ -421,7 +421,7 @@ def _apply_typical_sampling(
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
     _, indices = torch.sort(surprisal_deviations)
     _, indices = torch.sort(surprisal_deviations)
     reordered_probs = probs.gather(-1, indices)
     reordered_probs = probs.gather(-1, indices)
-    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p
+    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
     
     
     min_tokens_to_keep = 1
     min_tokens_to_keep = 1
     # Keep at least min_tokens_to_keep
     # Keep at least min_tokens_to_keep