|
@@ -418,7 +418,7 @@ def _apply_typical_sampling(
|
|
neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
|
|
neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
|
|
|
|
|
|
surprisal_deviations = (neg_entropy - shifted_logits).abs()
|
|
surprisal_deviations = (neg_entropy - shifted_logits).abs()
|
|
- _, indices = torch.sort(entropy_deviation)
|
|
|
|
|
|
+ _, indices = torch.sort(surprisal_deviation)
|
|
reordered_probs = probs.gather(-1, sorted_indices)
|
|
reordered_probs = probs.gather(-1, sorted_indices)
|
|
typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_ps
|
|
typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_ps
|
|
|
|
|
|
@@ -667,4 +667,4 @@ def _sample(
|
|
assert sample_idx == num_tokens
|
|
assert sample_idx == num_tokens
|
|
category_start_idx += num_tokens
|
|
category_start_idx += num_tokens
|
|
|
|
|
|
- return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
|
|
|
|
|
+ return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|