@@ -7,12 +7,11 @@ import torch.nn as nn

 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_gather
-)
+    tensor_model_parallel_all_gather)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
-from aphrodite.common.sequence import (
-    PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData,
-    SequenceGroupOutputs, SequenceOutputs)
+from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
+                                       SamplerOutput, SequenceData,
+                                       SequenceGroupOutputs, SequenceOutputs)
 from aphrodite.common.sequence import SamplerOutput, SequenceOutputs, SequenceData

 _SAMPLING_EPS = 1e-5
@@ -54,18 +53,20 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
+            input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens,
-                                  presence_penalties, frequency_penalties, repetition_penalties,
+        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+                                  frequency_penalties, repetition_penalties,
                                   self.vocab_size)
-
+
         banned_tokens = _get_custom_token_bans(input_metadata)
         assert len(banned_tokens) == logits.shape[0]
         logits = _apply_token_bans(logits, banned_tokens)
-
-        logits = _apply_logits_processors(input_metadata, logits, output_tokens)
+
+        logits = _apply_logits_processors(input_metadata, logits,
+                                          output_tokens)

         # Apply Eta sampling, as described in https://arxiv.org/abs/2210.15191
         eta_cutoffs = _get_eta_cutoffs(input_metadata)
@@ -101,7 +102,8 @@ class Sampler(nn.Module):
         logits.div_(t.unsqueeze(dim=1))

         # Apply top-p, top-k, and top-a truncation.
-        top_ps, top_ks, top_as = _get_top_a_top_p_top_k(input_metadata, self.vocab_size)
+        top_ps, top_ks, top_as = _get_top_a_top_p_top_k(
+            input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
@@ -141,7 +143,7 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
-) -> torch.Tensor:
+) -> torch.Tensor:
     selected_token_indices: List[int] = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
@@ -166,6 +168,7 @@ def _prune_hidden_states(
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     return hidden_states.index_select(0, selected_token_indices)

+
 def _get_penalties(
         input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
     # Collect the presence and frequency penalties.
@@ -181,8 +184,10 @@ def _get_penalties(
             frequency_penalties += [0] * (prompt_len - 1)
             repetition_penalties += [0] * (prompt_len - 1)
         presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
-        frequency_penalties += [sampling_params.frequency_penalty] * len(seq_ids)
-        repetition_penalties += [sampling_params.repetition_penalty] * len(seq_ids)
+        frequency_penalties += [sampling_params.frequency_penalty
+                                ] * len(seq_ids)
+        repetition_penalties += [sampling_params.repetition_penalty
+                                 ] * len(seq_ids)
     return presence_penalties, frequency_penalties, repetition_penalties


@@ -215,14 +220,12 @@ def _get_custom_token_bans(input_metadata: InputMetadata) -> List[List[int]]:
     return banned_tokens


-def _apply_logits_processors(
-    input_metadata: InputMetadata,
-    logits: torch.Tensor,
-    output_tokens: List[List[int]]
-) -> torch.Tensor:
+def _apply_logits_processors(input_metadata: InputMetadata,
+                             logits: torch.Tensor,
+                             output_tokens: List[List[int]]) -> torch.Tensor:
     seq_offset = 0

-    for seq_ids,sampling_params in input_metadata.seq_groups:
+    for seq_ids, sampling_params in input_metadata.seq_groups:
         seq_end = seq_offset + len(seq_ids)

         for proc in sampling_params.logits_processors:
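Note on the loop above: judging from the arguments this diff threads through _apply_logits_processors (per-sequence output_tokens next to the matching logits rows), each entry in sampling_params.logits_processors is a callable that maps the tokens generated so far plus a logits row to a new logits row. A hypothetical processor under that assumption (the argument order and the helper name are illustrative, not part of the codebase):

import torch

def ban_eos_until(min_tokens: int, eos_token_id: int):
    # Hypothetical factory: suppress EOS until `min_tokens` tokens have been generated.
    def processor(output_tokens: list, logits: torch.Tensor) -> torch.Tensor:
        if len(output_tokens) < min_tokens:
            logits[eos_token_id] = -float("inf")
        return logits
    return processor

# e.g. SamplingParams(..., logits_processors=[ban_eos_until(16, eos_token_id)])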
@@ -232,6 +235,7 @@ def _apply_logits_processors(

     return logits

+
 def _apply_penalties(
     logits: torch.Tensor,
     output_tokens: List[List[int]],
@@ -244,9 +248,9 @@ def _apply_penalties(
     for i in range(num_seqs):
         if not output_tokens[i]:
             continue
-        if (abs(presence_penalties[i]) < _SAMPLING_EPS and
-                abs(frequency_penalties[i]) < _SAMPLING_EPS and
-                repetition_penalties[i] < 1.0 + _SAMPLING_EPS):
+        if (abs(presence_penalties[i]) < _SAMPLING_EPS
+                and abs(frequency_penalties[i]) < _SAMPLING_EPS
+                and repetition_penalties[i] < 1.0 + _SAMPLING_EPS):
             continue
         break
     else:
@@ -278,8 +282,8 @@ def _apply_penalties(
                                        dtype=logits.dtype,
                                        device=logits.device)
     repetition_penalties = torch.tensor(repetition_penalties,
-                                       dtype=logits.dtype,
-                                       device=logits.device)
+                                        dtype=logits.dtype,
+                                        device=logits.device)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
@@ -289,13 +293,16 @@ def _apply_penalties(

     # Effectively: If token is present and logit is positive, divide logit by rep_pen.
     # If token is present and logit is negative, multiply logit by rep_pen.
-    logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits > 0)
-    logits += logits * (repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits < 0)
+    logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) -
+                        1) * presence_mask * (logits > 0)
+    logits += logits * (repetition_penalties.unsqueeze(dim=1) -
+                        1) * presence_mask * (logits < 0)

     return logits


-def _apply_token_bans(logits: torch.Tensor, banned_tokens: List[List[int]]) -> torch.Tensor:
+def _apply_token_bans(logits: torch.Tensor,
+                      banned_tokens: List[List[int]]) -> torch.Tensor:
     for i, banned_token_ids in enumerate(banned_tokens):
         if not banned_token_ids:
             continue
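Worked example of the masked repetition-penalty update above (a standalone sketch with made-up values, not part of the diff): with repetition_penalty = 2.0, a token that already appeared has a positive logit divided by 2 and a negative logit multiplied by 2, so a previously seen token always becomes less likely.

import torch

logits = torch.tensor([[3.0, -3.0, 1.0]])
presence_mask = torch.tensor([[1.0, 1.0, 0.0]])  # first two tokens were already generated
repetition_penalties = torch.tensor([2.0])

logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits > 0)
logits += logits * (repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits < 0)
print(logits)  # tensor([[ 1.5000, -6.0000,  1.0000]])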
@@ -340,7 +347,7 @@ def _get_top_a_top_p_top_k(
             prompt_len = input_metadata.prompt_lens[i]
             top_ps += [sampling_params.top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
-            top_as += [sampling_params.top_a] * (prompt_len - 1)
+            top_as += [sampling_params.top_a] * (prompt_len - 1)
         top_ps += [sampling_params.top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
         top_as += [sampling_params.top_a] * len(seq_ids)
@@ -415,11 +422,13 @@ def _apply_top_a_top_p_top_k(
     probs_sort = logits_sort.softmax(dim=-1)
     probs_sum = probs_sort.cumsum(dim=-1)
     top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * ts_a
-    top_ap_mask = (probs_sort < top_a_thresholds.unsqueeze(1))  # Cull logits below the top-a threshold
-    top_ap_mask.logical_or_(probs_sum > ts_p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
-    top_ap_mask[:, 0] = False  # Guarantee at least one token is pickable
+    top_ap_mask = (probs_sort < top_a_thresholds.unsqueeze(1)
+                   )  # Cull logits below the top-a threshold
+    top_ap_mask.logical_or_(probs_sum > ts_p.unsqueeze(
+        dim=1))  # Cull logits above the top-p summation threshold
+    top_ap_mask[:, 0] = False  # Guarantee at least one token is pickable
     logits_sort[top_ap_mask] = -float("inf")
-
+
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
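For reference, a minimal standalone sketch of the top-a rule used in the hunk above, with assumed values: a token survives only if its probability is at least top_a * p_max**2, where p_max is the largest probability in the row.

import torch

probs_sort = torch.tensor([[0.50, 0.30, 0.15, 0.05]])  # already sorted, descending
top_a = torch.tensor([0.5])

thresholds = torch.pow(probs_sort[:, 0], 2) * top_a    # 0.5 * 0.5**2 = 0.125
mask = probs_sort < thresholds.unsqueeze(1)            # cull tokens below the threshold
mask[:, 0] = False                                     # always keep the most likely token
print(mask)  # tensor([[False, False, False,  True]]) -> only the 0.05 token is dropped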
@@ -433,6 +442,7 @@ def _apply_top_a_top_p_top_k(
                           index=torch.argsort(logits_idx, dim=-1))
     return logits

+
 def _apply_tfs(
     logits: torch.Tensor,
     tfss: List[float],
@@ -446,14 +456,16 @@ def _apply_tfs(
     tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)

     tfs_mask = torch.cat(
-        (
-            torch.zeros(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
-            tfs_mask,
-            torch.ones(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
-        ),
-        dim=-1,
-    )
-
+        (
+            torch.zeros(
+                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
+            tfs_mask,
+            torch.ones(
+                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
+        ),
+        dim=-1,
+    )
+
     logits_sort[tfs_mask] = -float("inf")
     logits = torch.gather(logits_sort,
                           dim=-1,
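A minimal sketch of the tail-free-sampling mask that _apply_tfs builds, assuming curvature_cdf is the normalized cumulative sum of the absolute second derivative of the sorted probabilities (made-up values; the padding mirrors the torch.cat above, which never cuts the first token and always allows cutting the last one):

import torch

probs_sort = torch.tensor([[0.5, 0.25, 0.15, 0.07, 0.03]])  # sorted, descending
z = torch.tensor([0.95])

d2 = probs_sort.diff(dim=-1).diff(dim=-1).abs()         # discrete second derivative
curvature_cdf = (d2 / d2.sum(dim=-1, keepdim=True)).cumsum(dim=-1)
tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)          # cut once the curvature mass exceeds z
tfs_mask = torch.cat(
    (torch.zeros(1, 1, dtype=torch.bool), tfs_mask, torch.ones(1, 1, dtype=torch.bool)),
    dim=-1)
print(tfs_mask)  # tensor([[False, False, False,  True,  True]])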
@@ -462,21 +474,22 @@ def _apply_tfs(
     return logits


-
 def _apply_eta_cutoff(
     logits: torch.Tensor,
     eta_cutoffs: List[float],
 ) -> torch.Tensor:
-    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device) * 1e-4
+    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype,
+                       device=logits.device) * 1e-4
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()

     neg_entropy = (probs * shifted_logits).nansum(dim=-1)
-    eps = torch.min(eta, torch.sqrt(eta)*torch.exp(neg_entropy)).unsqueeze(dim=1)
+    eps = torch.min(eta,
+                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)

     eta_mask = probs < eps

-    if(torch.all(eta_mask)):  # guard against nulling out all the logits
+    if (torch.all(eta_mask)):  # guard against nulling out all the logits
         topk_prob, _ = torch.max(probs, dim=-1)
         eta_mask = probs < topk_prob

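Sketch of the eta-cutoff threshold computed above, on made-up values: eps = min(eta, sqrt(eta) * exp(-entropy)), and tokens whose probability falls below eps are masked (unless that would mask everything, which the guard above prevents). In the diff the user-supplied eta_cutoffs are scaled by 1e-4; the value below is assumed to be already scaled.

import torch

probs = torch.tensor([[0.60, 0.30, 0.07, 0.03]])
eta = torch.tensor([0.05])

neg_entropy = (probs * probs.log()).nansum(dim=-1)   # equals -H(p)
eps = torch.min(eta, torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)
print(eps)          # ~0.05: the entropy-based term (~0.086) is larger, so eta wins
print(probs < eps)  # tensor([[False, False, False,  True]])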
@@ -488,12 +501,14 @@ def _apply_epsilon_cutoff(
     logits: torch.Tensor,
     epsilon_cutoffs: List[float],
 ) -> torch.Tensor:
-    eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device).unsqueeze(dim=1)
+    eps = torch.tensor(epsilon_cutoffs,
+                       dtype=logits.dtype,
+                       device=logits.device).unsqueeze(dim=1)
     probs = logits.softmax(dim=-1)

     eps_mask = probs < (eps * 1e-4)

-    if(torch.all(eps_mask)):  # guard against nulling out all the logits
+    if (torch.all(eps_mask)):  # guard against nulling out all the logits
         topk_prob, _ = torch.max(probs, dim=-1)
         eps_mask = probs < topk_prob

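The epsilon cutoff above is simpler: the user-facing value is scaled by 1e-4, so epsilon_cutoff = 3.0 removes every token whose probability is below 3e-4, with the same keep-at-least-one guard. A short illustration with assumed values:

import torch

probs = torch.tensor([[0.7000, 0.2991, 0.0007, 0.0002]])
eps = torch.tensor([3.0]).unsqueeze(dim=1)
print(probs < (eps * 1e-4))  # tensor([[False, False, False,  True]])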
@@ -515,17 +530,16 @@ def _apply_typical_sampling(
     _, indices = torch.sort(surprisal_deviations)
     reordered_probs = probs.gather(-1, indices)
     typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
-
+
     min_tokens_to_keep = 1
     # Keep at least min_tokens_to_keep
     typ_mask_sorted[..., :min_tokens_to_keep] = 0

-    typ_mask = typ_mask_sorted.scatter(
-        1, indices, typ_mask_sorted
-    )
+    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
     logits[typ_mask] = -float("inf")
     return logits

+
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
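A standalone sketch of the locally-typical-sampling mask that _apply_typical_sampling builds (assumed values; the deviation formula is inferred from the surrounding code): tokens are ordered by how close their surprisal is to the distribution's entropy, and once the cumulative probability of that ordering reaches typ_p the rest are masked.

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
typ_p = torch.tensor([0.9])

neg_entropy = (probs * probs.log()).nansum(dim=-1, keepdim=True)
surprisal_deviations = (neg_entropy - probs.log()).abs()   # | -log p_i - H |
_, indices = torch.sort(surprisal_deviations)
reordered_probs = probs.gather(-1, indices)
typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
typ_mask_sorted[..., :1] = 0                               # keep at least one token
typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
print(typ_mask)  # tensor([[False, False,  True,  True]]) -> the two least typical tokens are cut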
@@ -680,12 +694,13 @@ def _sample(
                                              category_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
-
+
         sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
-        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
-    ]
+        sample_results_dict[i]
+        for i in range(len(input_metadata.seq_groups))
+    ]
    return sample_results


@@ -822,4 +837,4 @@ def _build_sampler_output(
             SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
         sampler_output.append(
             SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
-    return sampler_output
+    return sampler_output