AlpinDale 5 months ago
Parent
Current commit
fd50b83b6f
1 changed file with 14 additions and 13 deletions

+ 14 - 13
aphrodite/modeling/layers/sampler.py

@@ -645,36 +645,37 @@ def _apply_jsd_sampling(
     logits: torch.Tensor,
     jsd_thresholds: torch.Tensor,
 ) -> torch.Tensor:
-    """Applies Jensen-Shannon Distance-based sampling to filter tokens.
+    """Applies Jensen-Shannon Divergence-based sampling to filter tokens.
 
     Args:
         logits: Tensor of shape (batch_size, vocab_size) containing the logits.
         jsd_thresholds: Tensor of shape (batch_size,) with the JSD threshold for each sequence.
 
     Returns:
-        Modified logits tensor with tokens beyond the JSD threshold masked out.
+        Modified logits tensor with tokens exceeding the JSD threshold masked out.
     """
     # Compute the probability distribution from logits
-    probs = torch.softmax(logits, dim=-1)
+    probs = torch.softmax(logits, dim=-1)  # Shape: (batch_size, vocab_size)
 
     # Create a uniform distribution
-    uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1))
+    uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1))  # Shape: (batch_size, vocab_size)
 
     # Compute the average distribution
-    average_probs = 0.5 * (probs + uniform_probs)
+    average_probs = 0.5 * (probs + uniform_probs)  # Shape: (batch_size, vocab_size)
 
-    # Compute the Kullback-Leibler divergences
-    kl_probs = torch.sum(probs * (torch.log(probs + 1e-8) - torch.log(average_probs + 1e-8)), dim=-1)
-    kl_uniform = torch.sum(uniform_probs * (torch.log(uniform_probs + 1e-8) - torch.log(average_probs + 1e-8)), dim=-1)
+    # Compute the per-token KL divergence terms
+    kl_probs = probs * (torch.log(probs + 1e-8) - torch.log(average_probs + 1e-8))  # Shape: (batch_size, vocab_size)
+    kl_uniform = uniform_probs * (torch.log(uniform_probs + 1e-8) - torch.log(average_probs + 1e-8))  # Shape: (batch_size, vocab_size)
 
-    # Calculate the Jensen-Shannon Distance
-    jsd = 0.5 * (kl_probs + kl_uniform)
+    # Compute the per-token JSD contribution
+    jsd = 0.5 * (kl_probs + kl_uniform)  # Shape: (batch_size, vocab_size)
 
-    # Create a mask for tokens where JSD is less than the threshold
-    jsd_mask = jsd.unsqueeze(-1) < jsd_thresholds.unsqueeze(-1)
+    # Compare each token's JSD contribution against the threshold
+    jsd_thresholds_expanded = jsd_thresholds.unsqueeze(-1)  # Shape: (batch_size, 1)
+    jsd_mask = jsd >= jsd_thresholds_expanded  # Shape: (batch_size, vocab_size)
 
     # Mask out tokens where the JSD exceeds the threshold
-    logits = logits.masked_fill(~jsd_mask, -float("inf"))
+    logits = logits.masked_fill(jsd_mask, -float("inf"))
     return logits
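
For reference, a minimal standalone sketch of the function as it reads after this commit. The body mirrors the + lines above; the toy logits, the threshold value 0.08, and the print at the end are illustrative assumptions, not part of the commit.

import torch

def _apply_jsd_sampling(logits: torch.Tensor, jsd_thresholds: torch.Tensor) -> torch.Tensor:
    # Probability distribution over the vocabulary for each sequence
    probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)
    # Uniform reference distribution and the midpoint of the two
    uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1))
    average_probs = 0.5 * (probs + uniform_probs)
    # Per-token KL divergence terms against the averaged distribution
    kl_probs = probs * (torch.log(probs + 1e-8) - torch.log(average_probs + 1e-8))
    kl_uniform = uniform_probs * (torch.log(uniform_probs + 1e-8) - torch.log(average_probs + 1e-8))
    # Per-token JSD contribution; tokens at or above the threshold are masked out
    jsd = 0.5 * (kl_probs + kl_uniform)
    jsd_mask = jsd >= jsd_thresholds.unsqueeze(-1)  # broadcast (batch_size, 1) over vocab
    return logits.masked_fill(jsd_mask, -float("inf"))

# Illustrative call: one peaked and one uniform row over a 4-token vocabulary.
logits = torch.tensor([[4.0, 1.0, 0.5, 0.1],
                       [2.0, 2.0, 2.0, 2.0]])
out = _apply_jsd_sampling(logits, torch.tensor([0.08, 0.08]))
print(out)  # the uniform row matches the average distribution exactly, so none of its tokens are masked

Compared with the removed code, which summed the KL terms into a single scalar per sequence and then broadcast one keep/drop decision across the whole vocabulary, the per-token comparison lets each candidate token be filtered independently against the sequence's threshold.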