@@ -50,12 +50,12 @@ class Sampler(nn.Module):

         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+        logits = _apply_penalties(logits, output_tokens,
+                                  presence_penalties, frequency_penalties, repetition_penalties,
+                                  self.vocab_size)
         logits = _apply_logits_processors(input_metadata, logits, output_tokens)

@@ -75,13 +75,14 @@ class Sampler(nn.Module):
         # Use in-place division to avoid creating a new tensor.
         logits.div_(t.unsqueeze(dim=1))

-        # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        # Apply top-p, top-k, and top-a truncation.
+        top_ps, top_ks, top_as = _get_top_a_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+        do_top_a = any(a > _SAMPLING_EPS for a in top_as)
+        if do_top_p or do_top_k or do_top_a:
+            logits = _apply_top_a_top_p_top_k(logits, top_ps, top_ks, top_as)

         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
@@ -142,13 +143,13 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
-        p = sampling_params.presence_penalty
-        f = sampling_params.frequency_penalty
-        presence_penalties += [p] * len(seq_ids)
-        frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
+        frequency_penalties += [sampling_params.frequency_penalty] * len(seq_ids)
+        repetition_penalties += [sampling_params.repetition_penalty] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
@@ -180,14 +181,16 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
+    vocab_size: int,
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
         if not output_tokens[i]:
             continue
-        p = presence_penalties[i]
-        f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        if (abs(presence_penalties[i]) < _SAMPLING_EPS and
+            abs(frequency_penalties[i]) < _SAMPLING_EPS and
+            repetition_penalties[i] < 1.0 + _SAMPLING_EPS):
             continue
         break
     else:
@@ -218,11 +221,21 @@ def _apply_penalties(
     presence_penalties = torch.tensor(presence_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
+    repetition_penalties = torch.tensor(repetition_penalties,
+                                        dtype=logits.dtype,
+                                        device=logits.device)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    presence_mask = (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * presence_mask
+
+    # Effectively: if a token is present and its logit is positive, divide the logit by rep_pen;
+    # if it is present and its logit is negative, multiply the logit by rep_pen.
+    logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits > 0)
+    logits += logits * (repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits < 0)
+
     return logits


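Aside (not part of the patch): a minimal standalone sketch of what the new repetition-penalty lines do to one row of logits, assuming a penalty of 1.3 and a made-up set of already-generated tokens. Positive logits of seen tokens are divided by the penalty and negative ones are multiplied by it, so both move toward lower probability; unseen tokens are untouched.

import torch

# Hypothetical values; only the divide-if-positive / multiply-if-negative rule
# mirrors the hunk above.
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
seen = torch.tensor([True, True, False, False])   # tokens already generated
rep_pen = 1.3

penalized = logits.clone()
penalized[seen & (logits > 0)] /= rep_pen   #  2.0 -> ~1.54
penalized[seen & (logits < 0)] *= rep_pen   # -1.0 -> -1.30
print(penalized)  # tensor([ 1.5385, -1.3000,  0.5000, -3.0000])
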
@@ -241,22 +254,26 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     return temperatures


-def _get_top_p_top_k(
+def _get_top_a_top_p_top_k(
     input_metadata: InputMetadata,
     vocab_size: int,
-) -> Tuple[List[float], List[int]]:
+) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
+    top_as: List[float] = []
     for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
-        top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        top_ps += [top_p] * len(seq_ids)
+
+        top_ps += [sampling_params.top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
-    return top_ps, top_ks
+        top_as += [sampling_params.top_a] * len(seq_ids)
+
+    return top_ps, top_ks, top_as
+


 def _get_tfs(input_metadata: InputMetadata) -> List[float]:
@@ -268,26 +285,31 @@ def _get_tfs(input_metadata: InputMetadata) -> List[float]:
     return tfss


-def _apply_top_p_top_k(
+def _apply_top_a_top_p_top_k(
     logits: torch.Tensor,
     top_ps: List[float],
     top_ks: List[int],
+    top_as: List[float],
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    ts_p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
+    ts_k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    ts_a = torch.tensor(top_as, dtype=logits.dtype, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

-    # Apply top-p.
+    # Apply top-p and top-a.
     probs_sort = logits_sort.softmax(dim=-1)
     probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    logits_sort[top_p_mask] = -float("inf")
-
+    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * ts_a
+    top_ap_mask = (probs_sort < top_a_thresholds.unsqueeze(1))  # Cull logits below the top-a threshold
+    top_ap_mask.logical_or_(probs_sum > ts_p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
+    top_ap_mask[:, 0] = False  # Guarantee at least one token is pickable
+    logits_sort[top_ap_mask] = -float("inf")
+
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
     top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
+    top_k_mask = top_k_mask >= ts_k.unsqueeze(dim=1)
     logits_sort[top_k_mask] = -float("inf")

     # Re-sort the probabilities.
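Aside (not part of the patch): a rough, self-contained sketch of the combined mask built in _apply_top_a_top_p_top_k above; the function name and tensor shapes here are illustrative only. Top-a drops tokens whose probability falls below top_a * p_max**2 (p_max being the highest probability in the row), top-p drops the tail once the cumulative probability passes top_p, top-k keeps only the k highest-scoring slots, and the single most likely token is always kept.

import torch

def sketch_top_a_p_k(logits: torch.Tensor, top_p: float, top_k: int,
                     top_a: float) -> torch.Tensor:
    # Sort once, build one boolean mask from all three criteria, then undo the sort.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    probs = logits_sort.softmax(dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    mask = probs < (probs[..., :1] ** 2) * top_a                            # top-a threshold from the max prob
    mask |= cum_probs > top_p                                               # cut the low-probability tail
    mask |= torch.arange(probs.shape[-1], device=logits.device) >= top_k    # keep only the first k sorted slots
    mask[..., 0] = False                                                    # always keep the most likely token
    logits_sort[mask] = float("-inf")
    # Scatter the masked, sorted logits back to their original vocabulary positions.
    return torch.empty_like(logits_sort).scatter_(-1, logits_idx, logits_sort)

# e.g. sketch_top_a_p_k(torch.tensor([[2.0, 1.0, 0.1, -1.0]]), top_p=0.9, top_k=3, top_a=0.2)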