Stefan Gligorijevic 1 年之前
父節點
當前提交
5649694aaa
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      aphrodite/modeling/layers/sampler.py

+ 1 - 1
aphrodite/modeling/layers/sampler.py

@@ -303,7 +303,7 @@ def _apply_tfs(
     z = torch.tensor(tfss, dtype=logits.dtype, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
     d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
-    normalized_d2 = d2 / torch.sum(d2, dim=-1)
+    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
     curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
 
     tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)