|
@@ -416,7 +416,7 @@ def _apply_quadratic_sampling(
|
|
|
Credits: @kalomaze
|
|
|
"""
|
|
|
max_logits = logits.max(dim=-1, keepdim=True).values
|
|
|
- transformed_logits = -(smoothing_factors *
|
|
|
+ transformed_logits = -(smoothing_factors.unsqueeze_(dim=1) *
|
|
|
(logits - max_logits).pow(2)) + max_logits
|
|
|
return transformed_logits
|
|
|
|