Преглед на файлове

fix: crash in quadratic sampling when batch > 1 (#253)

Co-authored-by: anon <anon@example.org>
anon998 преди 1 година
родител
ревизия
35b9033782
променени са 1 файла, в които са добавени 1 реда и са изтрити 1 реда
  1. 1 1
      aphrodite/modeling/layers/sampler.py

+ 1 - 1
aphrodite/modeling/layers/sampler.py

@@ -416,7 +416,7 @@ def _apply_quadratic_sampling(
     Credits: @kalomaze
     """
     max_logits = logits.max(dim=-1, keepdim=True).values
-    transformed_logits = -(smoothing_factors *
+    transformed_logits = -(smoothing_factors.unsqueeze_(dim=1) *
                            (logits - max_logits).pow(2)) + max_logits
     return transformed_logits