Browse Source

fix: crash in quadratic sampling when batch > 1 (#253)

Co-authored-by: anon <anon@example.org>
anon998 1 year ago
parent
commit
35b9033782
1 changed files with 1 additions and 1 deletions
  1. 1 1
      aphrodite/modeling/layers/sampler.py

+ 1 - 1
aphrodite/modeling/layers/sampler.py

@@ -416,7 +416,7 @@ def _apply_quadratic_sampling(
     Credits: @kalomaze
     """
     max_logits = logits.max(dim=-1, keepdim=True).values
-    transformed_logits = -(smoothing_factors *
+    transformed_logits = -(smoothing_factors.unsqueeze_(dim=1) *
                            (logits - max_logits).pow(2)) + max_logits
     return transformed_logits