|
@@ -384,8 +384,8 @@ def _apply_eta_cutoff(
|
|
|
eta_mask = probs < eps
|
|
|
|
|
|
if(torch.all(eta_mask)): # guard against nulling out all the logits
|
|
|
- _, max_idx = torch.max(probs, dim=-1)
|
|
|
- eta_mask[max_idx] = False
|
|
|
+ topk_prob, _ = torch.max(probs, dim=-1)
|
|
|
+ eta_mask = probs < topk_prob
|
|
|
|
|
|
logits[eta_mask] = -float("inf")
|
|
|
return logits
|
|
@@ -395,14 +395,14 @@ def _apply_epsilon_cutoff(
|
|
|
logits: torch.Tensor,
|
|
|
epsilon_cutoffs: List[float],
|
|
|
) -> torch.Tensor:
|
|
|
- probs = torch.softmax(logits, dim=-1)
|
|
|
eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device)
|
|
|
+ probs = logits.softmax(dim=-1)
|
|
|
|
|
|
- eps_mask = probs < eps
|
|
|
+ eps_mask = probs < (eps * 1e-4)
|
|
|
|
|
|
if(torch.all(eps_mask)): # guard against nulling out all the logits
|
|
|
- _, max_idx = torch.max(probs, dim=-1)
|
|
|
- eps_mask[max_idx] = False
|
|
|
+ topk_prob, _ = torch.max(probs, dim=-1)
|
|
|
+ eps_mask = probs < topk_prob
|
|
|
|
|
|
logits[eps_mask] = -float("inf")
|
|
|
return logits
|
|
@@ -412,15 +412,16 @@ def _apply_typical_sampling(
|
|
|
logits: torch.Tensor,
|
|
|
typical_ps: List[float],
|
|
|
) -> torch.Tensor:
|
|
|
+ typ_p = torch.tensor(typical_ps, dtype=logits.dtype, device=logits.device)
|
|
|
shifted_logits = torch.log_softmax(logits, dim=-1)
|
|
|
probs = shifted_logits.exp()
|
|
|
|
|
|
neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
|
|
|
|
|
|
surprisal_deviations = (neg_entropy - shifted_logits).abs()
|
|
|
- _, indices = torch.sort(surprisal_deviation)
|
|
|
- reordered_probs = probs.gather(-1, sorted_indices)
|
|
|
- typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_ps
|
|
|
+ _, indices = torch.sort(surprisal_deviations)
|
|
|
+ reordered_probs = probs.gather(-1, indices)
|
|
|
+ typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p
|
|
|
|
|
|
min_tokens_to_keep = 1
|
|
|
# Keep at least min_tokens_to_keep
|
|
@@ -667,4 +668,4 @@ def _sample(
|
|
|
assert sample_idx == num_tokens
|
|
|
category_start_idx += num_tokens
|
|
|
|
|
|
- return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
|
|
+ return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|