Stefan Gligorijevic 1 gadu atpakaļ
vecāks
revīzija
5649694aaa
1 mainītis faili ar 1 papildinājumiem un 1 dzēšanām
  1. 1 1
      aphrodite/modeling/layers/sampler.py

+ 1 - 1
aphrodite/modeling/layers/sampler.py

@@ -303,7 +303,7 @@ def _apply_tfs(
     z = torch.tensor(tfss, dtype=logits.dtype, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
     d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
-    normalized_d2 = d2 / torch.sum(d2, dim=-1)
+    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
     curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
 
     tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)