|
@@ -303,7 +303,7 @@ def _apply_tfs(
|
|
|
z = torch.tensor(tfss, dtype=logits.dtype, device=logits.device)
|
|
|
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
|
|
d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
|
|
|
- normalized_d2 = d2 / torch.sum(d2, dim=-1)
|
|
|
+ normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
|
|
|
curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
|
|
|
|
|
|
tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)
|