@@ -3,6 +3,7 @@ from typing import Dict, List, Tuple, Optional

 import torch
 import torch.nn as nn
+import math

 from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   OutputMetadata,
@@ -137,49 +138,50 @@ def _perform_sampling(
     logits = _apply_logits_processors(logits, sampling_metadata)

     # Prepare sampling tensors with pinned memory to avoid blocking.
-    (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
-     do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-     do_typical_ps, do_quadratic,
-     do_mirostat) = (SamplingTensors.from_sampling_metadata(
-         sampling_metadata, vocab_size, logits.device, logits.dtype))
+    sampling_tensors = SamplingTensors.from_sampling_metadata(
+        sampling_metadata, vocab_size, logits.device, logits.dtype)

-    if do_penalties:
+    if sampling_tensors.do_penalties:
         logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                   sampling_tensors.output_tokens,
-                                  sampling_tensors.presence_penalties,
-                                  sampling_tensors.frequency_penalties,
-                                  sampling_tensors.repetition_penalties)
+                                  sampling_tensors.pres_penalties,
+                                  sampling_tensors.freq_penalties,
+                                  sampling_tensors.rep_penalties)

-    if do_temperatures:
+    if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
         logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                     sampling_tensors.dynatemp_mins,
                                     sampling_tensors.dynatemp_maxs,
                                     sampling_tensors.dynatemp_exps)

-    if do_topks or do_topps or do_topas or do_minps:
+    if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
+            or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
         logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                       sampling_tensors.top_ks,
                                       sampling_tensors.top_as,
                                       sampling_tensors.min_ps)
-    if do_tfss:
+
+    if sampling_tensors.do_tfss:
         logits = _apply_tfs(logits, sampling_tensors.tfss)
-    if do_eta_cutoffs:
+    if sampling_tensors.do_eta_cutoffs:
         logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
-    if do_epsilon_cutoffs:
+    if sampling_tensors.do_epsilon_cutoffs:
         logits = _apply_epsilon_cutoff(logits,
                                        sampling_tensors.epsilon_cutoffs)
-    if do_typical_ps:
+    if sampling_tensors.do_typical_ps:
         logits = _apply_typical_sampling(logits, sampling_tensors.typical_ps)
-    if do_quadratic:
+
+    if sampling_tensors.do_quadratic:
         logits = _apply_quadratic_sampling(logits,
+                                           sampling_tensors.smoothing_indices,
                                            sampling_tensors.smoothing_factors,
                                            sampling_tensors.smoothing_curves)

     banned_tokens = _get_custom_token_bans(sampling_metadata)
     assert len(banned_tokens) == logits.shape[0]
     logits = _apply_token_bans(logits, banned_tokens)
-    if do_mirostat:
-        logits = _mirostat(logits, sampling_tensors, output_metadata)
+    if sampling_tensors.do_mirostat:
+        logits = _apply_mirostat_v2(logits, sampling_tensors)

     # We use float32 for probabilities and log probabilities.
     # Compute the probabilities.
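
Note: the hunk above assumes `SamplingTensors.from_sampling_metadata` now returns a single object carrying both the parameter tensors and the batch-level `do_*` flags. A minimal sketch of that shape (illustrative only; field names beyond those referenced in the hunk are assumptions, and the real class lives in aphrodite.modeling.sampling_metadata):

    from dataclasses import dataclass
    import torch

    @dataclass
    class SamplingTensors:  # sketch, not the actual definition
        # Batch-level switches, computed once at construction so the
        # sampling hot path only does cheap attribute reads.
        do_temperatures: bool
        do_penalties: bool
        do_top_ks: bool
        do_mirostat: bool
        # Per-sequence parameters, already on the right device/dtype.
        temperatures: torch.Tensor
        top_ks: torch.Tensor
        miro_indices: torch.Tensor
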
@@ -190,6 +192,10 @@ def _perform_sampling(
     # Sample the next tokens.
     sample_results = _sample(probs, logprobs, sampling_metadata)
+
+    if sampling_tensors.do_mirostat:
+        _mirostat_store_args(logits, sampling_tensors, sample_results,
+                             sampling_metadata, output_metadata)
     # Get the logprobs query results.
     prompt_logprobs, sample_logprobs = _get_logprobs(logprobs,
                                                      sampling_metadata,
@@ -239,53 +245,32 @@ def _get_custom_token_bans(
     return banned_tokens


-# def _apply_logits_processors(
-#     logits: torch.Tensor,
-#     metadata: SamplingMetadata,
-# ) -> torch.Tensor:
-#     seq_offset = 0
-#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
-#         seq_size = len(seq_ids)
-#         output_tokens = []
-#         if (i < metadata.num_prompts
-#                 and sampling_params.prompt_logprobs is not None):
-#             prompt_seqs = metadata.prompt_lens[i] - 1
-#             seq_size += prompt_seqs
-#             output_tokens.extend([[]] * prompt_seqs)
-#         seq_end = seq_offset + seq_size
-
-#         if sampling_params.logits_processors:
-#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
-#                                  for sid in seq_ids)
-#             for proc in sampling_params.logits_processors:
-#                 proc(logits[seq_offset:seq_end], output_tokens)
-
-#         seq_offset = seq_end
-
-#     return logits
-
-
 def _apply_logits_processors(
     logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
+    metadata: SamplingMetadata,
 ) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
+    assert metadata.seq_groups is not None
+    assert metadata.prompt_lens is not None
+    assert metadata.seq_data is not None
+    seq_offset = 0
+    for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
+        seq_size = len(seq_ids)
+        output_tokens = []
+        if (i < metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_seqs = metadata.prompt_lens[i] - 1
+            seq_size += prompt_seqs
+            output_tokens.extend([[]] * prompt_seqs)
+        seq_end = seq_offset + seq_size
+
+        if sampling_params.logits_processors:
+            output_tokens.extend(metadata.seq_data[sid].output_token_ids
+                                 for sid in seq_ids)
+            for proc in sampling_params.logits_processors:
+                proc(logits[seq_offset:seq_end], output_tokens)
+
+        seq_offset = seq_end
+
     return logits
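
Note the calling-convention change this hunk makes for logits processors: the old loop invoked `logits_row = processor(token_ids, logits_row)` one row at a time, while the restored implementation hands each processor the whole seq-group slice plus the per-sequence output-token lists and expects in-place mutation. A hypothetical processor written against the new convention (names are illustrative, not part of the patch):

    from typing import List
    import torch

    def make_eos_ban(eos_token_id: int, min_tokens: int):
        """Hypothetical example: forbid EOS until a sequence has produced
        at least `min_tokens` tokens. Mutates the slice in place."""

        def proc(logits: torch.Tensor, output_tokens: List[List[int]]) -> None:
            # `logits` is the [group_size, vocab] view logits[seq_offset:seq_end]
            for row, toks in enumerate(output_tokens):
                if len(toks) < min_tokens:
                    logits[row, eos_token_id] = -float("inf")

        return proc
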
@@ -398,20 +383,19 @@ def _apply_eta_cutoff(
     logits: torch.Tensor,
     eta_cutoff: torch.Tensor,
 ) -> torch.Tensor:
-    eta = torch.tensor(eta_cutoff, dtype=logits.dtype,
-                       device=logits.device) * 1e-4
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()

     neg_entropy = (probs * shifted_logits).nansum(dim=-1)
-    eps = torch.min(eta,
-                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)
+    eps = torch.min(eta_cutoff,
+                    torch.sqrt(eta_cutoff) *
+                    torch.exp(neg_entropy)).unsqueeze(dim=1)

     eta_mask = probs < eps

-    if torch.all(eta_mask):  # guard against nulling out all the logits
-        topk_prob, _ = torch.max(probs, dim=-1)
-        eta_mask = probs < topk_prob
+    # guard against nulling out all the logits
+    top_idx = torch.argmax(probs, dim=1, keepdim=True)
+    eta_mask.scatter_(dim=1, index=top_idx, value=False)

     logits[eta_mask] = -float("inf")
     return logits
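
For reference, eta sampling (Hewitt et al., 2022) prunes tokens whose probability falls below eps = min(eta, sqrt(eta) * exp(-H(p))); since `neg_entropy` is -H(p), `torch.exp(neg_entropy)` supplies the exp(-H) term. The old in-function `* 1e-4` rescaling is gone, so the cutoff tensor presumably arrives pre-scaled from SamplingTensors. A standalone toy run of the same math:

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.0, -8.0]])
    eta = torch.tensor([3e-4])  # assumed already scaled to a probability

    logp = torch.log_softmax(logits, dim=-1)
    probs = logp.exp()
    neg_entropy = (probs * logp).nansum(dim=-1)              # this is -H(p)
    eps = torch.min(eta, eta.sqrt() * neg_entropy.exp()).unsqueeze(dim=1)
    mask = probs < eps   # only the ~3e-5 probability tail token is pruned
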
@@ -421,16 +405,13 @@ def _apply_epsilon_cutoff(
     logits: torch.Tensor,
     epsilon_cutoff: torch.Tensor,
 ) -> torch.Tensor:
-    eps = torch.tensor(epsilon_cutoff,
-                       dtype=logits.dtype,
-                       device=logits.device).unsqueeze(dim=1)
     probs = logits.softmax(dim=-1)

-    eps_mask = probs < (eps * 1e-4)
+    eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)

-    if torch.all(eps_mask):  # guard against nulling out all the logits
-        topk_prob, _ = torch.max(probs, dim=-1)
-        eps_mask = probs < topk_prob
+    # guard against nulling out all the logits
+    top_idx = torch.argmax(probs, dim=1, keepdim=True)
+    eps_mask.scatter_(dim=1, index=top_idx, value=False)

     logits[eps_mask] = -float("inf")
     return logits
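
Epsilon sampling is the simpler fixed-floor variant: drop any token whose probability is below epsilon, but never the argmax. The same shape in isolation, with toy inputs (again assuming the 1e-4 scaling now happens upstream):

    import torch

    logits = torch.tensor([[3.0, 1.0, -6.0]])
    eps = torch.tensor([1e-3])

    probs = logits.softmax(dim=-1)
    eps_mask = probs < eps.unsqueeze(dim=1)
    # Keep the most likely token unconditionally.
    top_idx = torch.argmax(probs, dim=1, keepdim=True)
    eps_mask.scatter_(dim=1, index=top_idx, value=False)
    logits[eps_mask] = -float("inf")   # prunes only the ~1e-4 tail token
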
@@ -440,7 +421,6 @@ def _apply_typical_sampling(
     logits: torch.Tensor,
     typical_p: torch.Tensor,
 ) -> torch.Tensor:
-    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()
@@ -449,7 +429,8 @@ def _apply_typical_sampling(
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
     _, indices = torch.sort(surprisal_deviations)
     reordered_probs = probs.gather(-1, indices)
-    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
+    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
+        dim=1)

     min_tokens_to_keep = 1
     # Keep at least min_tokens_to_keep
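
Locally typical sampling keeps the tokens whose surprisal sits closest to the distribution's entropy until their cumulative mass reaches typical_p; the sort above orders tokens by that |H(p) + log p(x)| deviation. A toy illustration of the ranking step (here `neg_entropy` is computed with keepdim=True, as earlier in this function):

    import torch

    logits = torch.tensor([[2.0, 1.5, 0.5, -1.0]])
    logp = torch.log_softmax(logits, dim=-1)
    probs = logp.exp()
    neg_entropy = (probs * logp).nansum(dim=-1, keepdim=True)  # -H(p)
    deviations = (neg_entropy - logp).abs()   # |surprisal - H| per token
    _, order = torch.sort(deviations)         # most "typical" tokens first
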
@@ -493,8 +474,9 @@ def _apply_temperature(

 def _apply_quadratic_sampling(
     logits: torch.Tensor,
-    smoothing_factors: torch.Tensor,
-    smoothing_curves: torch.Tensor,
+    indices: torch.Tensor,
+    factors: torch.Tensor,
+    curves: torch.Tensor,
 ) -> torch.Tensor:
     """
     Applies a quadratic transformation to the logits based on the
@@ -508,9 +490,11 @@ def _apply_quadratic_sampling(

     params:
         logits (torch.Tensor): The logits to be transformed.
-        smoothing_factors (torch.Tensor): The factors to scale the quadratic
+        indices (torch.Tensor): The batch indices of the rows in `logits`
+            that the transformation applies to.
+        factors (torch.Tensor): The factors to scale the quadratic
            term in the transformation.
-        smoothing_curves (torch.Tensor): The factors to scale the cubic term
+        curves (torch.Tensor): The factors to scale the cubic term
            in the transformation.

     returns:
@@ -518,20 +502,20 @@ def _apply_quadratic_sampling(

     Credits: @kalomaze
     """
-    max_logits = logits.max(dim=-1, keepdim=True).values
-    diff = logits - max_logits
-    smoothing_factors.unsqueeze_(dim=1)
-    smoothing_curves.unsqueeze_(dim=1)
-
-    k = (3 - smoothing_curves) / 2
-    s = (smoothing_curves - 1) / 2
-
-    mask = smoothing_factors > 0
-    mask = mask.flatten()
-    transformed_logits = torch.where(
-        logits != float('-inf'), -(k * smoothing_factors * diff**2) +
-        (s * smoothing_factors * diff**3) + max_logits, logits)
-    logits[mask, :] = transformed_logits[mask, :]
+    factors.unsqueeze_(dim=1)
+    curves.unsqueeze_(dim=1)
+    k = factors * (3 - curves) / 2
+    s = factors * (curves - 1) / 2
+
+    quadlogits = logits[indices]  # project to only relevant logits
+    max_logits = quadlogits.max(dim=-1, keepdim=True).values
+
+    # Construct the delta from each logit to its new value
+    diff = quadlogits - max_logits
+    diff -= diff**2 * (s * diff - k)
+    diff[diff != diff] = 0  # Eliminate NaNs from infs
+
+    logits[indices] -= diff
     return logits
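
The rewritten body subtracts a delta in place of building a second full-size tensor with torch.where, and only touches the rows selected by `indices`. The two formulations agree algebraically: with d = logit - max, k = f(3-c)/2 and s = f(c-1)/2, both yield max - k*d^2 + s*d^3. A quick numerical sanity check (a sketch, not part of the patch):

    import torch

    f, c = 0.3, 1.5                       # smoothing factor and curve
    logits = torch.tensor([1.0, 0.2, -2.0])
    d = logits - logits.max()

    # Old form: full transformed tensor.
    old = logits.max() - (3 - c) / 2 * f * d**2 + (c - 1) / 2 * f * d**3

    # New form: subtract the delta d + k*d^2 - s*d^3 from the logits,
    # with k and s pre-multiplied by the factor.
    k, s = f * (3 - c) / 2, f * (c - 1) / 2
    new = logits - (d - d**2 * (s * d - k))

    assert torch.allclose(old, new)
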
@@ -539,7 +523,6 @@ def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
-    samples = samples.tolist()
     sample_idx = 0
     results = []
     for seq_group in selected_seq_groups:
@@ -548,7 +531,7 @@ def _greedy_sample(
         assert num_parent_seqs == 1, (
             "Greedy sampling should have only one seq.")
         parent_ids = list(range(num_parent_seqs))
-        next_token_ids = [samples[sample_idx]]
+        next_token_ids = [samples[sample_idx].item()]
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
     return results
@@ -671,6 +654,10 @@ def _sample(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
+    """Returns list of (selected_tokens, parent_seq_ids) tuples
+    corresponding to sampling_metadata.seq_groups."""
+    assert sampling_metadata.seq_groups is not None
+    assert sampling_metadata.categorized_sample_indices is not None
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
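
Concretely, the documented return value pairs sampled token ids with the indices of the parent sequences they extend, one tuple per seq group; for example (values hypothetical):

    # Group 0: regular sampling, one token from parent seq 0.
    # Group 1: best_of=2, two candidates both forked from parent seq 0.
    sample_results = [
        ([1024], [0]),
        ([77, 318], [0, 0]),
    ]
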
@@ -860,92 +847,88 @@ def _build_sampler_output(
     sample_logprobs: List[SampleLogprobs],
     output_metadata: OutputMetadata,
 ) -> SamplerOutput:
+    assert sampling_metadata.seq_groups is not None
     sampler_output = []
     for (seq_group, sample_result, group_prompt_logprobs,
          group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                        sample_results, prompt_logprobs,
                                        sample_logprobs):
         seq_ids, _ = seq_group
-        next_token_ids, parent_ids = sample_result
-        seq_outputs = []
-        for parent_id, next_token_id, logprobs in zip(parent_ids,
-                                                      next_token_ids,
-                                                      group_sample_logprobs):
-            seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
-                               output_metadata.get(seq_ids[parent_id])))
+        seq_outputs = [
+            SequenceOutput(seq_ids[parent_id], token_id, logprobs,
+                           output_metadata.get(seq_ids[parent_id], idx))
+            for idx, (token_id, parent_id, logprobs) in enumerate(
+                zip(*sample_result, group_sample_logprobs))
+        ]
+
         sampler_output.append(
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
     return sampler_output


-def _miro_store_args(seqids: List[int], mus: List[float],
-                     output_metadata: OutputMetadata) -> None:
-    for sid, mu in zip(seqids,
-                       mus.tolist()):  # tolist might be premature optimization
-        output_metadata.add(sid, "miro_mu", mu)
+def _apply_mirostat_v2(logits: torch.Tensor,
+                       sampling_tensors: SamplingTensors) -> torch.Tensor:
+    # Reduce our view to just the affected logits
+    logit_view = logits[sampling_tensors.miro_indices]

+    # Calculate surprise value per token
+    # Convert nats to bits for compatibility with ooba/kobold parameters.
+    logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)

-def _apply_mirostat_v2(
-    logits: torch.Tensor,
-    taus: torch.Tensor,  # AKA the targeted surprise
-    etas: torch.Tensor,  # AKA the learning rate
-    mus: torch.
-    Tensor,  # AKA the accumulator that always tries to approach [tau]
-) -> torch.Tensor:
-
-    logit_surprise = torch.softmax(
-        logits, dim=-1).log2_().neg_()  # Calculate surprise value per token
-    # For compatibility with ooba/kobold, done in unit of bits(log base 2)
-    # not nats(ln).
-    # Ideally this would be a log_softmax, for numerical stability and
-    # elegance purposes.
-    # logit_surprise = torch.log_softmax(logits, dim=-1).neg_()
-
-    miro_mask = logit_surprise > mus.unsqueeze(
-        dim=-1)  # Mask out "too-surprising" tokens (above mu)
-    mininds = torch.argmin(logit_surprise, dim=-1)
-    miro_mask.scatter_(
-        1, mininds.unsqueeze(dim=-1), False
-    )  # Force at least one outcome to be possible, ideally the most likely one
-
-    logits[miro_mask] = -float("inf")
-
-    probs = torch.softmax(logits, dim=-1,
-                          dtype=logits.dtype)  # Get probs, post-mask
-
-    # NOTE: Mirostat updates its `mu` values based on the sample chosen.
-    # The silly approach here is to just sample it and make the logits one-hot.
-    # This breaks fine grained seeding, but we don't have that yet.
-    # TODO: FIX when it gets added
-    next_token_ids = _multinomial(probs, num_samples=1)
-
-    # Calculate new `mu` values
-    # NOTE: If we can know the logit values of the PREVIOUS iteration,
-    # it should be possible to update `mu` before applying mirostat each
-    # iteration, thus letting us keep _sample as the last thing that happens.
-    picked_surprises = torch.gather(logit_surprise,
-                                    dim=-1,
-                                    index=next_token_ids)
-    eps = picked_surprises.squeeze() - taus
-    mus.sub_(etas * eps)
-
-    logits.fill_(-float("inf"))
-    # This value doesn't actually matter, so long as it's not -inf.
-    # Vectors are now one-hot, after all.
-    logits.scatter_(1, next_token_ids, 1.0)
-    return logits
+    # Mask out "too-surprising" tokens (surprisal > mu)
+    mus = sampling_tensors.miro_mus
+    miro_mask = logit_surprise > mus.unsqueeze(dim=-1)

+    # Unmask most-likely logit to guarantee a selection.
+    maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
+    miro_mask.scatter_(dim=1, index=maxinds, value=False)

-def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
-              output_metadata: OutputMetadata) -> torch.Tensor:
-    idx = sampling_tensors.miro_indices
-    seqids = sampling_tensors.miro_seqids
-    taus = sampling_tensors.miro_taus
-    etas = sampling_tensors.miro_etas
-    mus = sampling_tensors.miro_mus
+    # Apply logit mask (effectively a dynamic top-k filter).
+    logit_view[miro_mask] = -float("inf")

-    logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
-                                     mus)  # mus is an i/o param, :vomit:
-    _miro_store_args(seqids, mus, output_metadata)
+    # Write the masked view back: advanced indexing returns a copy,
+    # so changes to logit_view do not propagate on their own.
+    logits[sampling_tensors.miro_indices] = logit_view
     return logits
+
+
+def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
+                         sample_results: List[Tuple[List[int], List[int]]],
+                         sampling_metadata: SamplingMetadata,
+                         output_metadata: OutputMetadata) -> None:
+    """Based on whichever token was finally sampled, we calculate the
+    final surprisal values to update the mus.
+
+    Because a single sequence can have multiple samples, we must fork
+    the mu accordingly."""
+    assert sampling_metadata.seq_groups is not None
+    seqid_to_tokens = {}
+    seqid_to_indices = {}
+    for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
+                                          sample_results):
+        for idx, (token, parent) in enumerate(zip(toks, parents)):
+            seqid_to_tokens.setdefault(sids[parent], []).append(token)
+            seqid_to_indices.setdefault(sids[parent], []).append(idx)
+
+    seqids = args.miro_seqids
+
+    picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
+                                 device=logits.device,
+                                 dtype=torch.long)
+
+    # Clumsily, we recalculate token surprisals.
+    logits_view = logits[args.miro_indices]
+    picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
+                                   dim=-1,
+                                   index=picked_tokens) / -math.log(2)
+
+    taus = args.miro_taus.unsqueeze(dim=-1)  # AKA target surprisals
+    etas = args.miro_etas.unsqueeze(dim=-1)  # AKA learning rates
+    mus = args.miro_mus.unsqueeze(dim=-1)  # AKA surprisal accumulators
+    nu_mus = mus - (picked_surprise - taus) * etas
+
+    # Record updated mu values for use in the next iteration.
+    # Note how each mu is split into multiple based on the number of samples.
+    for seqid, seq_mus in zip(seqids, nu_mus):
+        for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
+            output_metadata.add(seqid, sample_idx, "miro_mu", mu)
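
Taken together, `_apply_mirostat_v2` and `_mirostat_store_args` implement the mirostat v2 feedback loop: filter by the surprisal budget mu, sample, then move mu toward the target tau by eta times the error, which is exactly the `nu_mus = mus - (picked_surprise - taus) * etas` line above. A self-contained single-step sketch for one sequence (constants are illustrative):

    import math
    import torch

    tau, eta, mu = 5.0, 0.1, 10.0   # target surprisal (bits), learning rate, accumulator
    logits = torch.randn(32000)

    # Mask tokens more surprising than mu (in bits), keeping the argmax.
    surprise = torch.log_softmax(logits, dim=-1) / -math.log(2)
    mask = surprise > mu
    mask[logits.argmax()] = False
    logits[mask] = -float("inf")

    # Sample, then nudge mu toward tau based on the observed surprisal.
    token = torch.multinomial(logits.softmax(dim=-1), num_samples=1)
    mu -= eta * (surprise[token].item() - tau)
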