@@ -645,36 +645,37 @@ def _apply_jsd_sampling(
     logits: torch.Tensor,
     jsd_thresholds: torch.Tensor,
 ) -> torch.Tensor:
-    """Applies Jensen-Shannon Distance-based sampling to filter tokens.
+    """Applies Jensen-Shannon Divergence-based sampling to filter tokens.
 
     Args:
         logits: Tensor of shape (batch_size, vocab_size) containing the logits.
         jsd_thresholds: Tensor of shape (batch_size,) with the JSD threshold for each sequence.
 
     Returns:
-        Modified logits tensor with tokens beyond the JSD threshold masked out.
+        Modified logits tensor with tokens exceeding the JSD threshold masked out.
     """
     # Compute the probability distribution from logits
-    probs = torch.softmax(logits, dim=-1)
+    probs = torch.softmax(logits, dim=-1)  # Shape: (batch_size, vocab_size)
 
     # Create a uniform distribution
-    uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1))
+    uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1))  # Shape: (batch_size, vocab_size)
 
     # Compute the average distribution
-    average_probs = 0.5 * (probs + uniform_probs)
+    average_probs = 0.5 * (probs + uniform_probs)  # Shape: (batch_size, vocab_size)
 
-    # Compute the Kullback-Leibler divergences
-    kl_probs = torch.sum(probs * (torch.log(probs + 1e-8) - torch.log(average_probs + 1e-8)), dim=-1)
-    kl_uniform = torch.sum(uniform_probs * (torch.log(uniform_probs + 1e-8) - torch.log(average_probs + 1e-8)), dim=-1)
+    # Compute per-token KL divergence terms
+    kl_probs = probs * (torch.log(probs + 1e-8) - torch.log(average_probs + 1e-8))  # Shape: (batch_size, vocab_size)
+    kl_uniform = uniform_probs * (torch.log(uniform_probs + 1e-8) - torch.log(average_probs + 1e-8))  # Shape: (batch_size, vocab_size)
 
-    # Calculate the Jensen-Shannon Distance
-    jsd = 0.5 * (kl_probs + kl_uniform)
+    # Compute per-token JSD
+    jsd = 0.5 * (kl_probs + kl_uniform)  # Shape: (batch_size, vocab_size)
 
-    # Create a mask for tokens where JSD is less than the threshold
-    jsd_mask = jsd.unsqueeze(-1) < jsd_thresholds.unsqueeze(-1)
+    # Compare per-token JSD with threshold
+    jsd_thresholds_expanded = jsd_thresholds.unsqueeze(-1)  # Shape: (batch_size, 1)
+    jsd_mask = jsd >= jsd_thresholds_expanded  # Shape: (batch_size, vocab_size)
 
-    # Mask out tokens where the JSD exceeds the threshold
-    logits = logits.masked_fill(~jsd_mask, -float("inf"))
+    # Mask out tokens whose JSD meets or exceeds the threshold
+    logits = logits.masked_fill(jsd_mask, -float("inf"))
     return logits
 
 
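As a quick sanity check of the post-change behavior, here is a self-contained sketch (not part of the diff) that condenses the function as it reads after this hunk and runs it on toy inputs; the example logits and threshold values are made up for illustration and do not come from the codebase:

```python
import torch


def _apply_jsd_sampling(logits: torch.Tensor, jsd_thresholds: torch.Tensor) -> torch.Tensor:
    # Condensed body of the function as it reads after this diff.
    probs = torch.softmax(logits, dim=-1)
    uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1))
    average_probs = 0.5 * (probs + uniform_probs)
    kl_probs = probs * (torch.log(probs + 1e-8) - torch.log(average_probs + 1e-8))
    kl_uniform = uniform_probs * (torch.log(uniform_probs + 1e-8) - torch.log(average_probs + 1e-8))
    jsd = 0.5 * (kl_probs + kl_uniform)
    jsd_mask = jsd >= jsd_thresholds.unsqueeze(-1)  # (batch, vocab) vs. (batch, 1) broadcast
    return logits.masked_fill(jsd_mask, -float("inf"))


# Toy inputs (illustrative values only): one sequence, vocabulary of 4.
logits = torch.tensor([[4.0, 2.0, 0.0, -2.0]])
thresholds = torch.tensor([0.05])
print(_apply_jsd_sampling(logits, thresholds))
# Tokens whose per-token JSD term meets the threshold come back as -inf.
# Both very likely and very unlikely tokens deviate most from the uniform
# mixture, so tokens at either extreme can be masked.
```

Note the flipped mask polarity relative to the old code: `jsd_mask` now marks tokens to drop, so `masked_fill` takes it directly instead of its negation.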