@@ -374,12 +374,12 @@ def _apply_eta_cutoff(
     logits: torch.Tensor,
     eta_cutoffs: List[float],
 ) -> torch.Tensor:
-    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device)
+    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device) * 1e-4
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()
 
-    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
-    eps = torch.min(eta * 1e-4, torch.sqrt(eta*1e-4)*torch.exp(neg_entropy))
+    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
+    eps = torch.min(eta, torch.sqrt(eta)*torch.exp(neg_entropy)).unsqueeze(dim=1)
     eta_mask = probs < eps
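
Why this hunk fixes the eta cutoff: with keepdim=True, neg_entropy had shape [num_seqs, 1] while eta had shape [num_seqs], so torch.min broadcast the pair to [num_seqs, num_seqs] and the probs < eps comparison only worked for a batch of one. Dropping keepdim keeps everything [num_seqs], and a single unsqueeze produces the [num_seqs, 1] column needed to mask along the vocab axis. A minimal standalone sketch of the fixed shapes; the batch/vocab sizes and cutoff values below are illustrative assumptions, not taken from the patch:

# Sketch of the fixed eta-cutoff broadcasting; example cutoffs are in
# units of 1e-4, matching the scaling applied in the patch.
import torch

logits = torch.randn(2, 8)                      # [num_seqs, vocab]
eta = torch.tensor([3.0, 6.0]) * 1e-4           # [num_seqs], scaled once

shifted_logits = torch.log_softmax(logits, dim=-1)
probs = shifted_logits.exp()
neg_entropy = (probs * shifted_logits).nansum(dim=-1)  # [num_seqs], no keepdim

# Both operands of torch.min are [num_seqs], so eps is [num_seqs];
# unsqueeze(dim=1) turns it into the [num_seqs, 1] column the mask needs.
eps = torch.min(eta, torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)
eta_mask = probs < eps                          # broadcasts over the vocab dim
assert eta_mask.shape == (2, 8)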
@@ -395,7 +395,7 @@ def _apply_epsilon_cutoff(
     logits: torch.Tensor,
     epsilon_cutoffs: List[float],
 ) -> torch.Tensor:
-    eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device)
+    eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device).unsqueeze(dim=1)
     probs = logits.softmax(dim=-1)
 
     eps_mask = probs < (eps * 1e-4)
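
Same broadcasting issue in the epsilon cutoff: a [num_seqs] eps compared against [num_seqs, vocab] probs aligns on the vocab axis (and raises a shape error for batches larger than one); as a [num_seqs, 1] column it thresholds each row by its own cutoff. A small sketch with assumed shapes and cutoff values:

# Sketch of the epsilon-cutoff fix; batch/vocab sizes and cutoffs are
# illustrative. Cutoffs stay in units of 1e-4, as in the surrounding code.
import torch

logits = torch.randn(2, 8)                      # [num_seqs, vocab]
eps = torch.tensor([3.0, 9.0]).unsqueeze(dim=1) # [num_seqs, 1] column

probs = logits.softmax(dim=-1)
eps_mask = probs < (eps * 1e-4)                 # row-wise per-sequence threshold
assert eps_mask.shape == (2, 8)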
@@ -421,7 +421,7 @@ def _apply_typical_sampling(
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
     _, indices = torch.sort(surprisal_deviations)
     reordered_probs = probs.gather(-1, indices)
-    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p
+    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
 
     min_tokens_to_keep = 1
     # Keep at least min_tokens_to_keep
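
And the same one-line fix for typical sampling: typ_p holds one cumulative-mass threshold per sequence, so it has to become a [num_seqs, 1] column before being compared against the [num_seqs, vocab] cumulative sum. A sketch under those assumed shapes; the reordering by surprisal deviation that the full function performs is elided here:

# Sketch of the typical-sampling comparison after the fix; assumes probs
# are already reordered by surprisal deviation.
import torch

reordered_probs = torch.softmax(torch.randn(2, 8), dim=-1)  # [num_seqs, vocab]
typ_p = torch.tensor([0.90, 0.95])                          # [num_seqs]

typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
assert typ_mask_sorted.shape == (2, 8)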