|
@@ -59,6 +59,12 @@ class Sampler(nn.Module):
|
|
|
|
|
|
logits = _apply_logits_processors(input_metadata, logits, output_tokens)
|
|
|
|
|
|
+ # Apply Tail Free Sampling, as described in https://www.trentonbricken.com/Tail-Free-Sampling/
|
|
|
+ tfss = _get_tfs(input_metadata)
|
|
|
+ assert len(tfss) == logits.shape[0]
|
|
|
+ if any(z < 1.0 - _SAMPLING_EPS for z in tfss):
|
|
|
+ logits = _apply_tfs(logits, tfss)
|
|
|
+
|
|
|
# Apply temperature scaling.
|
|
|
temperatures = _get_temperatures(input_metadata)
|
|
|
assert len(temperatures) == logits.shape[0]
|
|
@@ -269,6 +275,16 @@ def _get_top_a_top_p_top_k(
|
|
|
return top_ps, top_ks, top_as
|
|
|
|
|
|
|
|
|
+
|
|
|
def _get_tfs(input_metadata: InputMetadata) -> List[float]:
    """Collect the Tail Free Sampling parameter for every sequence.

    Each entry of ``input_metadata.seq_groups`` is unpacked as a
    ``(seq_ids, sampling_params)`` pair; the group's ``sampling_params.tfs``
    value is repeated once per sequence id so the returned list has one
    entry per sequence, in group order.
    """
    return [
        tfs_value
        for group_ids, group_params in input_metadata.seq_groups
        # Read `tfs` once per group, then fan it out over the group's ids.
        for tfs_value in [group_params.tfs] * len(group_ids)
    ]
|
|
|
+
|
|
|
+
|
|
|
def _apply_top_a_top_p_top_k(
|
|
|
logits: torch.Tensor,
|
|
|
top_ps: List[float],
|
|
@@ -302,6 +318,34 @@ def _apply_top_a_top_p_top_k(
|
|
|
index=torch.argsort(logits_idx, dim=-1))
|
|
|
return logits
|
|
|
|
|
|
def _apply_tfs(
    logits: torch.Tensor,
    tfss: List[float],
) -> torch.Tensor:
    """Apply Tail Free Sampling (TFS) to a batch of logits.

    For every row, the logits are sorted in descending order and converted
    to probabilities. The absolute second difference of that sorted
    probability curve is normalized to sum to one and accumulated into a
    CDF; positions whose CDF value exceeds the row's threshold ``z`` are
    treated as the "tail" and set to ``-inf``.

    Args:
        logits: 2-D tensor with one row per sequence and the vocabulary
            along the last dimension.
        tfss: One TFS threshold per row of ``logits``.

    Returns:
        A new tensor shaped like ``logits`` with the tail entries replaced
        by ``-inf``. The input tensor is not modified.
    """
    z = torch.tensor(tfss, dtype=logits.dtype, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Absolute second-order difference of the sorted probabilities.
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    # keepdim=True keeps the per-row totals as a column so each row is
    # divided by its own sum. Without it the (rows,) vector is broadcast
    # along the vocabulary axis, which either raises a shape error or
    # silently mixes rows when rows == vocab - 2.
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)

    # The second difference is two columns shorter than the vocabulary.
    # Pad the mask back to full width: the highest-probability token is
    # always kept, the lowest-probability token is always dropped.
    tfs_mask = torch.cat(
        (
            torch.zeros(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    # Undo the sort so every value returns to its original vocabulary slot.
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))

    return logits
|
|
|
+
|
|
|
|
|
|
def _get_topk_logprobs(
|
|
|
logprobs: torch.Tensor,
|