|
@@ -59,12 +59,29 @@ class Sampler(nn.Module):
|
|
|
|
|
|
logits = _apply_logits_processors(input_metadata, logits, output_tokens)
|
|
logits = _apply_logits_processors(input_metadata, logits, output_tokens)
|
|
|
|
|
|
|
|
+ # Apply Eta sampling, as described in https://arxiv.org/abs/2210.15191
|
|
|
|
+ eta_cutoffs = _get_eta_cutoffs(input_metadata)
|
|
|
|
+ assert len(eta_cutoffs) == logits.shape[0]
|
|
|
|
+ if any(eta > _SAMPLING_EPS for eta in eta_cutoffs):
|
|
|
|
+ logits = _apply_eta_cutoff(logits, eta_cutoffs)
|
|
|
|
+
|
|
|
|
+ # Apply Locally typical sampling, as described in https://arxiv.org/abs/2202.00666
|
|
|
|
+ typical_ps = _get_typical_ps(input_metadata)
|
|
|
|
+ assert len(typical_ps) == logits.shape[0]
|
|
|
|
+ if any(typ_p < 1.0 - _SAMPLING_EPS for typ_p in typical_ps):
|
|
|
|
+ logits = _apply_typical_sampling(logits, typical_ps)
|
|
|
|
+
|
|
# Apply Tail Free Sampling, as described in https://www.trentonbricken.com/Tail-Free-Sampling/
|
|
# Apply Tail Free Sampling, as described in https://www.trentonbricken.com/Tail-Free-Sampling/
|
|
tfss = _get_tfs(input_metadata)
|
|
tfss = _get_tfs(input_metadata)
|
|
assert len(tfss) == logits.shape[0]
|
|
assert len(tfss) == logits.shape[0]
|
|
if any(z < 1.0 - _SAMPLING_EPS for z in tfss):
|
|
if any(z < 1.0 - _SAMPLING_EPS for z in tfss):
|
|
logits = _apply_tfs(logits, tfss)
|
|
logits = _apply_tfs(logits, tfss)
|
|
|
|
|
|
|
|
+ epsilon_cutoffs = _get_epsilon_cutoffs(input_metadata)
|
|
|
|
+ assert len(epsilon_cutoffs) == logits.shape[0]
|
|
|
|
+ if any(epsilon > _SAMPLING_EPS for epsilon in epsilon_cutoffs):
|
|
|
|
+ logits = _apply_epsilon_cutoff(logits, epsilon_cutoffs)
|
|
|
|
+
|
|
# Apply temperature scaling.
|
|
# Apply temperature scaling.
|
|
temperatures = _get_temperatures(input_metadata)
|
|
temperatures = _get_temperatures(input_metadata)
|
|
assert len(temperatures) == logits.shape[0]
|
|
assert len(temperatures) == logits.shape[0]
|
|
@@ -285,6 +302,33 @@ def _get_tfs(input_metadata: InputMetadata) -> List[float]:
|
|
return tfss
|
|
return tfss
|
|
|
|
|
|
|
|
|
|
|
|
+def _get_eta_cutoffs(input_metadata: InputMetadata) -> List[float]:
|
|
|
|
+ eta_cutoffs: List[float] = []
|
|
|
|
+ for seq_group in input_metadata.seq_groups:
|
|
|
|
+ seq_ids, sampling_params = seq_group
|
|
|
|
+ eta_cutoff = sampling_params.eta_cutoff
|
|
|
|
+ eta_cutoffs += [eta_cutoff] * len(seq_ids)
|
|
|
|
+ return eta_cutoffs
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def _get_epsilon_cutoffs(input_metadata: InputMetadata) -> List[float]:
|
|
|
|
+ epsilon_cutoffs: List[float] = []
|
|
|
|
+ for seq_group in input_metadata.seq_groups:
|
|
|
|
+ seq_ids, sampling_params = seq_group
|
|
|
|
+ epsilon_cutoff = sampling_params.epsilon_cutoff
|
|
|
|
+ epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
|
|
|
|
+ return epsilon_cutoffs
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def _get_typical_ps(input_metadata: InputMetadata) -> List[float]:
|
|
|
|
+ typical_ps: List[float] = []
|
|
|
|
+ for seq_group in input_metadata.seq_groups:
|
|
|
|
+ seq_ids, sampling_params = seq_group
|
|
|
|
+ typical_p = sampling_params.typical_p
|
|
|
|
+ typical_ps += [typical_p] * len(seq_ids)
|
|
|
|
+ return typical_ps
|
|
|
|
+
|
|
|
|
+
|
|
def _apply_top_a_top_p_top_k(
|
|
def _apply_top_a_top_p_top_k(
|
|
logits: torch.Tensor,
|
|
logits: torch.Tensor,
|
|
top_ps: List[float],
|
|
top_ps: List[float],
|
|
@@ -347,6 +391,71 @@ def _apply_tfs(
|
|
return logits
|
|
return logits
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
|
|
+def _apply_eta_cutoff(
|
|
|
|
+ logits: torch.Tensor,
|
|
|
|
+ eta_cutoffs: List[float],
|
|
|
|
+) -> torch.Tensor:
|
|
|
|
+ eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device) * 1e-4
|
|
|
|
+ shifted_logits = torch.log_softmax(logits, dim=-1)
|
|
|
|
+ probs = shifted_logits.exp()
|
|
|
|
+
|
|
|
|
+ neg_entropy = (probs * shifted_logits).nansum(dim=-1)
|
|
|
|
+ eps = torch.min(eta, torch.sqrt(eta)*torch.exp(neg_entropy)).unsqueeze(dim=1)
|
|
|
|
+
|
|
|
|
+ eta_mask = probs < eps
|
|
|
|
+
|
|
|
|
+ if(torch.all(eta_mask)): # guard against nulling out all the logits
|
|
|
|
+ topk_prob, _ = torch.max(probs, dim=-1)
|
|
|
|
+ eta_mask = probs < topk_prob
|
|
|
|
+
|
|
|
|
+ logits[eta_mask] = -float("inf")
|
|
|
|
+ return logits
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def _apply_epsilon_cutoff(
|
|
|
|
+ logits: torch.Tensor,
|
|
|
|
+ epsilon_cutoffs: List[float],
|
|
|
|
+) -> torch.Tensor:
|
|
|
|
+ eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device).unsqueeze(dim=1)
|
|
|
|
+ probs = logits.softmax(dim=-1)
|
|
|
|
+
|
|
|
|
+ eps_mask = probs < (eps * 1e-4)
|
|
|
|
+
|
|
|
|
+ if(torch.all(eps_mask)): # guard against nulling out all the logits
|
|
|
|
+ topk_prob, _ = torch.max(probs, dim=-1)
|
|
|
|
+ eps_mask = probs < topk_prob
|
|
|
|
+
|
|
|
|
+ logits[eps_mask] = -float("inf")
|
|
|
|
+ return logits
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def _apply_typical_sampling(
|
|
|
|
+ logits: torch.Tensor,
|
|
|
|
+ typical_ps: List[float],
|
|
|
|
+) -> torch.Tensor:
|
|
|
|
+ typ_p = torch.tensor(typical_ps, dtype=logits.dtype, device=logits.device)
|
|
|
|
+ shifted_logits = torch.log_softmax(logits, dim=-1)
|
|
|
|
+ probs = shifted_logits.exp()
|
|
|
|
+
|
|
|
|
+ neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
|
|
|
|
+
|
|
|
|
+ surprisal_deviations = (neg_entropy - shifted_logits).abs()
|
|
|
|
+ _, indices = torch.sort(surprisal_deviations)
|
|
|
|
+ reordered_probs = probs.gather(-1, indices)
|
|
|
|
+ typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
|
|
|
|
+
|
|
|
|
+ min_tokens_to_keep = 1
|
|
|
|
+ # Keep at least min_tokens_to_keep
|
|
|
|
+ typ_mask_sorted[..., :min_tokens_to_keep] = 0
|
|
|
|
+
|
|
|
|
+ typ_mask = typ_mask_sorted.scatter(
|
|
|
|
+ 1, indices, typ_mask_sorted
|
|
|
|
+ )
|
|
|
|
+ logits[typ_mask] = -float("inf")
|
|
|
|
+ return logits
|
|
|
|
+
|
|
|
|
+
|
|
def _get_topk_logprobs(
|
|
def _get_topk_logprobs(
|
|
logprobs: torch.Tensor,
|
|
logprobs: torch.Tensor,
|
|
num_logprobs: Optional[int],
|
|
num_logprobs: Optional[int],
|