|
@@ -2,11 +2,13 @@
|
|
|
import itertools
|
|
|
import os
|
|
|
import warnings
|
|
|
+from enum import IntEnum
|
|
|
from math import inf
|
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
+from loguru import logger
|
|
|
|
|
|
import aphrodite._custom_ops as ops
|
|
|
from aphrodite.common.sampling_params import SamplingType
|
|
@@ -36,6 +38,26 @@ APHRODITE_USE_SAMPLING_KERNELS = bool(int(
|
|
|
os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
|
|
|
|
|
|
|
|
|
+class SamplerID(IntEnum):
|
|
|
+ # Mirror these in aphrodite/common/sampling_params.py
|
|
|
+ # Values out of order to keep backwards compatibility
|
|
|
+ # with Koboldcpp values
|
|
|
+ DRY = 7
|
|
|
+ PENALTIES = 6
|
|
|
+ NO_REPEAT_NGRAM = 8
|
|
|
+ TEMPERATURE = 5
|
|
|
+ TOP_NSIGMA = 9
|
|
|
+ TOP_P_TOP_K = 0
|
|
|
+ TOP_A = 1
|
|
|
+ MIN_P = 2
|
|
|
+ TFS = 3
|
|
|
+ ETA_CUTOFF = 10
|
|
|
+ EPSILON_CUTOFF = 11
|
|
|
+ TYPICAL_P = 4
|
|
|
+ QUADRATIC = 12
|
|
|
+ XTC = 13
|
|
|
+
|
|
|
+
|
|
|
class Sampler(nn.Module):
|
|
|
"""Samples the next tokens from the model's outputs.
|
|
|
|
|
@@ -151,90 +173,242 @@ class Sampler(nn.Module):
|
|
|
do_temp_last = self._do_temp_last
|
|
|
|
|
|
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
|
|
|
-
|
|
|
- if do_dry:
|
|
|
- logits = _apply_dry(
|
|
|
- logits,
|
|
|
- sampling_tensors.prompt_tokens,
|
|
|
- sampling_tensors.dry_multipliers,
|
|
|
- sampling_tensors.dry_bases,
|
|
|
- sampling_tensors.dry_allowed_lengths,
|
|
|
- sampling_tensors.dry_sequence_breaker_ids
|
|
|
- )
|
|
|
-
|
|
|
- # Apply presence and frequency penalties.
|
|
|
- if do_penalties:
|
|
|
- logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
|
|
|
- sampling_tensors.output_tokens,
|
|
|
- sampling_tensors.presence_penalties,
|
|
|
- sampling_tensors.frequency_penalties,
|
|
|
- sampling_tensors.repetition_penalties)
|
|
|
-
|
|
|
- if do_no_repeat_ngrams:
|
|
|
- logits = _apply_no_repeat_ngram(
|
|
|
- logits,
|
|
|
- sampling_tensors.prompt_tokens,
|
|
|
- sampling_tensors.no_repeat_ngram_sizes)
|
|
|
-
|
|
|
- # Apply temperature scaling if not doing temp_last.
|
|
|
- if do_temperatures and not do_temp_last:
|
|
|
- _apply_temperatures(logits, sampling_tensors.temperatures,
|
|
|
- sampling_tensors.dynatemp_mins,
|
|
|
- sampling_tensors.dynatemp_maxs,
|
|
|
- sampling_tensors.dynatemp_exps)
|
|
|
-
|
|
|
- if do_nsigmas:
|
|
|
- logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
|
|
|
-
|
|
|
- if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS:
|
|
|
- logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
|
|
|
- sampling_tensors.top_ks)
|
|
|
-
|
|
|
- if do_top_as:
|
|
|
- logits = _apply_top_a(logits, sampling_tensors.top_as)
|
|
|
-
|
|
|
- if do_min_p:
|
|
|
- logits = _apply_min_p(logits, sampling_tensors.min_ps)
|
|
|
-
|
|
|
- if do_tfss:
|
|
|
- logits = _apply_tfs(logits, sampling_tensors.tfss)
|
|
|
-
|
|
|
- if do_eta_cutoffs:
|
|
|
- logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
|
|
|
-
|
|
|
- if do_epsilon_cutoffs:
|
|
|
- logits = _apply_epsilon_cutoff(logits,
|
|
|
- sampling_tensors.epsilon_cutoffs)
|
|
|
-
|
|
|
- if do_typical_ps:
|
|
|
- logits = _apply_typical_sampling(logits,
|
|
|
- sampling_tensors.typical_ps)
|
|
|
-
|
|
|
- if do_quadratic:
|
|
|
- logits = _apply_quadratic_sampling(
|
|
|
- logits, sampling_tensors.smoothing_factors,
|
|
|
- sampling_tensors.smoothing_curves)
|
|
|
-
|
|
|
- if do_xtc:
|
|
|
- logits = _apply_xtc_sampling(
|
|
|
- logits, sampling_tensors.xtc_thresholds,
|
|
|
- sampling_tensors.xtc_probabilities)
|
|
|
-
|
|
|
- if do_temperatures and do_temp_last:
|
|
|
- _apply_temperatures(logits, sampling_tensors.temperatures,
|
|
|
- sampling_tensors.dynatemp_mins,
|
|
|
- sampling_tensors.dynatemp_maxs,
|
|
|
- sampling_tensors.dynatemp_exps)
|
|
|
-
|
|
|
banned_tokens = _get_custom_token_bans(sampling_metadata)
|
|
|
logits = _apply_token_bans(logits, banned_tokens)
|
|
|
|
|
|
+ sampler_order = None
|
|
|
+ if sampling_metadata.seq_groups:
|
|
|
+ sampler_order = sampling_metadata.seq_groups[
|
|
|
+ 0].sampling_params.sampler_priority
|
|
|
+
|
|
|
+ # Warn if both custom order and temp_last are specified
|
|
|
+ if sampler_order is not None and do_temp_last:
|
|
|
+ logger.warning(
|
|
|
+ "Both sampler_priority and temperature_last=True "
|
|
|
+ "were specified. Using custom sampler_priority order "
|
|
|
+ "and ignoring temperature_last.")
|
|
|
+
|
|
|
+ if sampler_order is None:
|
|
|
+ default_order = [
|
|
|
+ SamplerID.DRY,
|
|
|
+ SamplerID.PENALTIES,
|
|
|
+ SamplerID.NO_REPEAT_NGRAM,
|
|
|
+ SamplerID.TEMPERATURE,
|
|
|
+ SamplerID.TOP_NSIGMA,
|
|
|
+ SamplerID.TOP_P_TOP_K,
|
|
|
+ SamplerID.TOP_A,
|
|
|
+ SamplerID.MIN_P,
|
|
|
+ SamplerID.TFS,
|
|
|
+ SamplerID.ETA_CUTOFF,
|
|
|
+ SamplerID.EPSILON_CUTOFF,
|
|
|
+ SamplerID.TYPICAL_P,
|
|
|
+ SamplerID.QUADRATIC,
|
|
|
+ SamplerID.XTC,
|
|
|
+ ]
|
|
|
+
|
|
|
+ sampler_order = []
|
|
|
+ for sampler_id in default_order:
|
|
|
+ if sampler_id == SamplerID.TEMPERATURE and do_temp_last:
|
|
|
+ continue
|
|
|
+ sampler_order.append(sampler_id)
|
|
|
+
|
|
|
+ if sampler_id == SamplerID.XTC and do_temp_last:
|
|
|
+ sampler_order.append(SamplerID.TEMPERATURE)
|
|
|
+
|
|
|
+ if sampling_metadata.seq_groups and sampling_metadata.seq_groups[
|
|
|
+ 0].is_prompt:
|
|
|
+ logger.debug("Sampler execution order: ")
|
|
|
+ for i, sampler_id in enumerate(sampler_order, 1):
|
|
|
+ logger.debug(f"{i}. {SamplerID(sampler_id).name}")
|
|
|
+
|
|
|
+ enabled_samplers = []
|
|
|
+ # ruff: noqa: E701
|
|
|
+ if do_penalties: enabled_samplers.append("PENALTIES")
|
|
|
+ if do_no_repeat_ngrams: enabled_samplers.append("NO_REPEAT_NGRAM")
|
|
|
+ if do_temperatures: enabled_samplers.append("TEMPERATURE")
|
|
|
+ if do_top_p_top_k: enabled_samplers.append("TOP_P_TOP_K")
|
|
|
+ if do_top_as: enabled_samplers.append("TOP_A")
|
|
|
+ if do_min_p: enabled_samplers.append("MIN_P")
|
|
|
+ if do_tfss: enabled_samplers.append("TFS")
|
|
|
+ if do_eta_cutoffs: enabled_samplers.append("ETA_CUTOFF")
|
|
|
+ if do_epsilon_cutoffs: enabled_samplers.append("EPSILON_CUTOFF")
|
|
|
+ if do_typical_ps: enabled_samplers.append("TYPICAL_P")
|
|
|
+ if do_quadratic: enabled_samplers.append("QUADRATIC")
|
|
|
+ if do_xtc: enabled_samplers.append("XTC")
|
|
|
+ if do_nsigmas: enabled_samplers.append("TOP_NSIGMA")
|
|
|
+ if do_dry: enabled_samplers.append("DRY")
|
|
|
+ if do_skew: enabled_samplers.append("SKEW")
|
|
|
+ logger.debug(f"Enabled samplers: {', '.join(enabled_samplers)}")
|
|
|
+
|
|
|
+ for sampler_id in sampler_order:
|
|
|
+ if sampler_id == SamplerID.DRY and do_dry:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ f"Applying DRY with dry_multiplier: "
|
|
|
+ f"{sampling_tensors.dry_multipliers}.")
|
|
|
+ logits = _apply_dry(
|
|
|
+ logits,
|
|
|
+ sampling_tensors.prompt_tokens,
|
|
|
+ sampling_tensors.dry_multipliers,
|
|
|
+ sampling_tensors.dry_bases,
|
|
|
+ sampling_tensors.dry_allowed_lengths,
|
|
|
+ sampling_tensors.dry_sequence_breaker_ids)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.PENALTIES and do_penalties:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying penalties with "
|
|
|
+ f"pres_pen: {sampling_tensors.presence_penalties}, "
|
|
|
+ f"freq_pen: {sampling_tensors.frequency_penalties}, "
|
|
|
+ f"rep_pen: {sampling_tensors.repetition_penalties}.")
|
|
|
+ logits = _apply_penalties(
|
|
|
+ logits, sampling_tensors.prompt_tokens,
|
|
|
+ sampling_tensors.output_tokens,
|
|
|
+ sampling_tensors.presence_penalties,
|
|
|
+ sampling_tensors.frequency_penalties,
|
|
|
+ sampling_tensors.repetition_penalties)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.NO_REPEAT_NGRAM and \
|
|
|
+ do_no_repeat_ngrams:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying no_repeat_ngram with no_repeat_ngram_size: "
|
|
|
+ f"{sampling_tensors.no_repeat_ngram_sizes}.")
|
|
|
+ logits = _apply_no_repeat_ngram(
|
|
|
+ logits,
|
|
|
+ sampling_tensors.prompt_tokens,
|
|
|
+ sampling_tensors.no_repeat_ngram_sizes)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.TEMPERATURE and do_temperatures:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying temperatures with temperature: "
|
|
|
+ f"{sampling_tensors.temperatures}, "
|
|
|
+ f"dynatemp_min: {sampling_tensors.dynatemp_mins}, "
|
|
|
+ f"dynatemp_max: {sampling_tensors.dynatemp_maxs}, "
|
|
|
+ f"dynamtep_exp: {sampling_tensors.dynatemp_exps}.")
|
|
|
+ _apply_temperatures(
|
|
|
+ logits, sampling_tensors.temperatures,
|
|
|
+ sampling_tensors.dynatemp_mins,
|
|
|
+ sampling_tensors.dynatemp_maxs,
|
|
|
+ sampling_tensors.dynatemp_exps)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.TOP_NSIGMA and do_nsigmas:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Top-Nsigma with nsigma: "
|
|
|
+ f"{sampling_tensors.nsigmas}")
|
|
|
+ logits = _apply_top_nsigma(
|
|
|
+ logits, sampling_tensors.nsigmas)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.TOP_P_TOP_K and do_top_p_top_k and \
|
|
|
+ not APHRODITE_USE_SAMPLING_KERNELS:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Top-p and Top-k with top-p: "
|
|
|
+ f"{sampling_tensors.top_ps}, top_k: "
|
|
|
+ f"{sampling_tensors.top_ks}.")
|
|
|
+ logits = _apply_top_k_top_p(
|
|
|
+ logits, sampling_tensors.top_ps,
|
|
|
+ sampling_tensors.top_ks)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.TOP_A and do_top_as:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Top-a with Top-a: "
|
|
|
+ f"{sampling_tensors.top_as}.")
|
|
|
+ logits = _apply_top_a(
|
|
|
+ logits, sampling_tensors.top_as)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.MIN_P and do_min_p:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Min-p with Min-p: "
|
|
|
+ f"{sampling_tensors.min_ps}.")
|
|
|
+ logits = _apply_min_p(
|
|
|
+ logits, sampling_tensors.min_ps)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.TFS and do_tfss:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Tail-Free Sampling with tfs: "
|
|
|
+ f"{sampling_tensors.tfss}.")
|
|
|
+ logits = _apply_tfs(
|
|
|
+ logits, sampling_tensors.tfss)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.ETA_CUTOFF and do_eta_cutoffs:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying ETA Cutoff with eta_cutoff: "
|
|
|
+ f"{sampling_tensors.eta_cutoffs}.")
|
|
|
+ logits = _apply_eta_cutoff(
|
|
|
+ logits, sampling_tensors.eta_cutoffs)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.EPSILON_CUTOFF and do_epsilon_cutoffs:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Epsilon Cutoff with epsilon_cutoff: "
|
|
|
+ f"{sampling_tensors.epsilon_cutoffs}.")
|
|
|
+ logits = _apply_epsilon_cutoff(
|
|
|
+ logits, sampling_tensors.epsilon_cutoffs)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.TYPICAL_P and do_typical_ps:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Locally Typical Sampling with typical_p: "
|
|
|
+ f"{sampling_tensors.typical_ps}.")
|
|
|
+ logits = _apply_typical_sampling(
|
|
|
+ logits, sampling_tensors.typical_ps)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.QUADRATIC and do_quadratic:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Quadratic and Cubic Sampling with "
|
|
|
+ "smoothing_factors: "
|
|
|
+ f"{sampling_tensors.smoothing_factors},"
|
|
|
+ f" smoothing_curves: "
|
|
|
+ f"{sampling_tensors.smoothing_curves}.")
|
|
|
+ logits = _apply_quadratic_sampling(
|
|
|
+ logits, sampling_tensors.smoothing_factors,
|
|
|
+ sampling_tensors.smoothing_curves)
|
|
|
+
|
|
|
+ elif sampler_id == SamplerID.XTC and do_xtc:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Exclude Top Choices sampling with "
|
|
|
+ f"xtc_threshold: {sampling_tensors.xtc_thresholds}, "
|
|
|
+ "xtc_probability: "
|
|
|
+ f"{sampling_tensors.xtc_probabilities}.")
|
|
|
+ logits = _apply_xtc_sampling(
|
|
|
+ logits, sampling_tensors.xtc_thresholds,
|
|
|
+ sampling_tensors.xtc_probabilities)
|
|
|
+
|
|
|
+
|
|
|
# We use float32 for probabilities and log probabilities.
|
|
|
# Compute the probabilities.
|
|
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
|
|
|
|
|
# skew needs to be applied post-softmax
|
|
|
if do_skew:
|
|
|
+ if (sampling_metadata.seq_groups and
|
|
|
+ sampling_metadata.seq_groups[0].is_prompt):
|
|
|
+ logger.debug(
|
|
|
+ "Applying Skew sampling with skew: "
|
|
|
+ f"{sampling_tensors.skews}.")
|
|
|
# reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
|
|
|
cum_probs = torch.cumsum(probs, dim=-1)
|
|
|
cum_probs = torch.pow(cum_probs, torch.exp(
|