@@ -1,4 +1,3 @@
-import random
 from array import array
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -9,14 +8,10 @@ from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
 from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                     is_pin_memory_available,
-                                    make_tensor_with_pad, maybe_expand_dim)
+                                    make_tensor_with_pad)
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
-from aphrodite.triton_utils.sample import get_num_triton_sampler_splits

 _SAMPLING_EPS = 1e-5
-_SEED_0_REPLACEMENT = 3403598558
-# Some triton sampler related code is guarded before it is ready.
-_USE_TRITON_SAMPLER = False


 @dataclass
@@ -168,11 +163,12 @@ class SamplingMetadata:
                                                   target_device=device,
                                                   pin_memory=pin_memory)
         categorized_sample_indices = {
-            t: maybe_expand_dim(
-                async_tensor_h2d(seq_ids,
-                                 dtype=torch.int,
-                                 target_device=device,
-                                 pin_memory=pin_memory), 2, 2)
+            t: async_tensor_h2d(
+                seq_ids,
+                dtype=torch.int,
+                target_device=device,
+                pin_memory=pin_memory,
+            )
             for t, seq_ids in categorized_sample_indices.items()
         }
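
With `maybe_expand_dim` gone, each per-type entry lands on the device as a flat 1-D index tensor instead of a padded `[N, 2]` pair tensor. A standalone sketch of the shape change (plain PyTorch, not an Aphrodite helper):

```python
import torch

# Before: (logit_idx, sample_idx) pairs, expanded to [N, 2] for the
# Triton kernel. After: logit indices only, a flat [N] tensor.
pairs = torch.tensor([(0, 0), (1, 1), (2, 2)], dtype=torch.int)  # old, [3, 2]
flat = torch.tensor([0, 1, 2], dtype=torch.int)                  # new, [3]
assert flat.tolist() == [p for p, _ in pairs.tolist()]
```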
@@ -199,8 +195,8 @@ def _prepare_seq_groups(
     device: str,
     generators: Optional[Dict[str, torch.Generator]] = None,
     cache: Optional[SamplingMetadataCache] = None,
-) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
-        SamplingType, List[Tuple[int, int]]], int]:
+) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType,
+                                                        List[int]], int]:
     """Prepare sequence groups and indices for sampling.

     Args:
@@ -231,16 +227,13 @@ def _prepare_seq_groups(
     # Sampling type -> (
     # indices to sample/prompt logprob within pruned output logits,
     # indices to sample within pruned logits)
-    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
+    categorized_sample_indices: Dict[SamplingType, List[int]] = {
         t: []
         for t in SamplingType
     }
     # Index of logits to compute logprob. Logits include both prompt logprob
     # and sample logprob indices.
     logit_idx = 0
-    # Index to sample from a sample tensor. It is used by triton sample kernel.
-    # See `_sample_with_triton_kernel` for more details.
-    sample_idx = 0
     # Total number of prompts from given sequence groups.
     num_prompts = 0
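
The per-type lists now hold plain logit-row indices. For a hypothetical batch where seq groups 0 and 2 decode greedily and group 1 samples, the populated dict would look like:

```python
from aphrodite.common.sampling_params import SamplingType

# Hypothetical batch: seq groups 0 and 2 are greedy, group 1 is random.
categorized_sample_indices = {
    SamplingType.GREEDY: [0, 2],  # rows of the pruned logits tensor
    SamplingType.RANDOM: [1],
    SamplingType.RANDOM_SEED: [],
}
```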
@@ -261,10 +254,10 @@ def _prepare_seq_groups(
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        prompt_logprob_indices: List[int] = \
-            sample_obj.prompt_logprob_indices if cache is not None else []
-        sample_indices: List[int] = \
-            sample_obj.sample_indices if cache is not None else []
+        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
+                                             if cache is not None else [])
+        sample_indices: List[int] = (sample_obj.sample_indices
+                                     if cache is not None else [])
         do_sample = seq_group_metadata.do_sample

         if seq_group_metadata.is_prompt:
@@ -330,11 +323,8 @@ def _prepare_seq_groups(
         if do_sample:
             sample_indices.extend(range(logit_idx, logit_idx + sample_len))
             categorized_sample_indices[sampling_params.sampling_type].extend(
-                list(
-                    zip(range(logit_idx, logit_idx + sample_len),
-                        range(sample_idx, sample_idx + sample_len))))
-            logit_idx += sample_len
-            sample_idx += sample_len
+                list(range(logit_idx, logit_idx + sample_len)))
+            logit_idx += sample_len

         if cache is not None:
             sample_obj.sampling_params = sampling_params
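
The single surviving increment is what keeps consecutive groups from claiming the same logit rows. A toy trace of the bookkeeping, with assumed `sample_len` values:

```python
# Two decode-stage seq groups, one sequence each (sample_len == 1).
logit_idx = 0
sample_indices = []
for sample_len in (1, 1):
    sample_indices.extend(range(logit_idx, logit_idx + sample_len))
    logit_idx += sample_len  # advance past this group's rows
assert sample_indices == [0, 1]  # without the increment: [0, 0]
```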
@@ -353,7 +343,8 @@ def _prepare_seq_groups(
                 generator=generator,
                 is_prompt=is_prompt,
                 prompt_logprob_indices=list(prompt_logprob_indices),
-                sample_indices=list(sample_indices))
+                sample_indices=list(sample_indices),
+            )

         seq_groups.append(sample_obj)

@@ -395,9 +386,6 @@ class SamplingTensors:
     dry_sequence_breaker_ids: torch.Tensor
     dry_ranges: torch.Tensor
     skews: torch.Tensor
-    sampling_seeds: torch.Tensor
-    sample_indices: torch.Tensor
-    extra_seeds: Optional[torch.Tensor]
     prompt_tokens: torch.Tensor
     output_tokens: torch.Tensor

@@ -408,16 +396,8 @@ class SamplingTensors:
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-        *,
-        extra_seeds_to_generate: int = 0,
-        extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
                bool, bool, bool, bool, bool, bool, bool, bool, bool]:
-        """
-        extra_seeds_to_generate: extra seeds to generate using the
-        user-defined seed for each sequence.
-        extra_entropy: extra entropy to use when generating seeds.
-        """
         prompt_tokens: List[array] = []
         output_tokens: List[array] = []
         top_ks: List[int] = []
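
A sketch of what a call site unpacks now that the seed outputs are gone; `sampling_metadata`, `vocab_size`, and `logits` are assumed to be in scope (they come from sampler code this diff does not touch):

```python
# Hypothetical call site; the right-hand-side inputs are assumptions.
(sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
 do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
 do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
 do_dry, do_skews, do_temp_last) = SamplingTensors.from_sampling_metadata(
     sampling_metadata, vocab_size, logits.device, logits.dtype)
```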
@@ -442,8 +422,6 @@ class SamplingTensors:
         xtc_thresholds: List[float] = []
         xtc_probabilities: List[float] = []
         nsigmas: List[float] = []
-        sampling_seeds: List[List[int]] = []
-        sample_indices: List[int] = []
         dry_multipliers: List[float] = []
         dry_bases: List[float] = []
         dry_allowed_lengths: List[int] = []
@@ -468,13 +446,6 @@ class SamplingTensors:
         do_skews = False
         do_temp_last = False

-        if _USE_TRITON_SAMPLER:
-            prompt_best_of: List[int] = []
-
-            # We need one base seed per Triton slice.
-            seeds_to_generate = (extra_seeds_to_generate +
-                                 get_num_triton_sampler_splits(vocab_size))
-
         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
             seq_ids = seq_group.seq_ids
@@ -515,7 +486,6 @@ class SamplingTensors:

             do_temp_last |= params.temperature_last

-            is_prompt = seq_group.is_prompt
             wants_prompt_logprobs = params.prompt_logprobs is not None

             n_seqs = 0
@@ -557,28 +527,6 @@ class SamplingTensors:
             dry_ranges += [params.dry_range] * n_seqs
             skews += [params.skew] * n_seqs

-            if _USE_TRITON_SAMPLER:
-                if is_prompt:
-                    prompt_best_of.append(params.best_of)
-                    query_len = seq_group.query_len
-                    assert query_len is not None
-
-                seed = params.seed
-                is_greedy = params.sampling_type == SamplingType.GREEDY
-
-                for seq_id in seq_ids:
-                    seq_data = seq_group.seq_data[seq_id]
-                    extra_entropy = extra_entropy or ()
-                    seq_seeds = cls._get_sequence_seeds(
-                        seed,
-                        seq_data.get_len(),
-                        *extra_entropy,
-                        seq_id,
-                        seeds_to_generate=seeds_to_generate,
-                        is_greedy=is_greedy)
-                    sampling_seeds.append(seq_seeds)
-                sample_indices.extend(seq_group.sample_indices)
-
         if do_penalties or do_dry or do_no_repeat_ngrams:
             for seq_group in sampling_metadata.seq_groups:
                 seq_ids = seq_group.seq_ids
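
Seeded requests still sample deterministically, but through the `generators` dict of `torch.Generator` objects threaded into `_prepare_seq_groups` above, not through per-sequence seed tensors. A minimal sketch of that mechanism in plain PyTorch:

```python
import torch

gen = torch.Generator(device="cpu")
gen.manual_seed(42)                       # request-level seed
probs = torch.tensor([0.1, 0.2, 0.7])
tok_a = torch.multinomial(probs, 1, generator=gen)
gen.manual_seed(42)
tok_b = torch.multinomial(probs, 1, generator=gen)
assert torch.equal(tok_a, tok_b)          # same seed -> same sample
```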
@@ -598,43 +546,94 @@ class SamplingTensors:
                 output_tokens.append(seq_data.output_token_ids_array)

         sampling_tensors = SamplingTensors.from_lists(
-            temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
-            temperature_lasts, top_ps, top_ks, top_as, min_ps,
-            presence_penalties, frequency_penalties, repetition_penalties,
-            no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
-            typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
-            xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
-            dry_allowed_lengths, dry_sequence_breaker_ids, dry_ranges, skews,
-            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
-            vocab_size, extra_seeds_to_generate, device, dtype)
-        return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
-                do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
-                do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
-                do_quadratic, do_xtc, do_nsigmas, do_dry, do_skews,
-                do_temp_last)
+            temperatures,
+            dynatemp_mins,
+            dynatemp_maxs,
+            dynatemp_exps,
+            temperature_lasts,
+            top_ps,
+            top_ks,
+            top_as,
+            min_ps,
+            presence_penalties,
+            frequency_penalties,
+            repetition_penalties,
+            no_repeat_ngram_sizes,
+            tfss,
+            eta_cutoffs,
+            epsilon_cutoffs,
+            typical_ps,
+            smoothing_factors,
+            smoothing_curves,
+            xtc_thresholds,
+            xtc_probabilities,
+            nsigmas,
+            dry_multipliers,
+            dry_bases,
+            dry_allowed_lengths,
+            dry_sequence_breaker_ids,
+            dry_ranges,
+            skews,
+            prompt_tokens,
+            output_tokens,
+            vocab_size,
+            device,
+            dtype)
+        return (
+            sampling_tensors,
+            do_penalties,
+            do_no_repeat_ngrams,
+            do_temperatures,
+            do_top_p_top_k,
+            do_top_as,
+            do_min_p,
+            do_tfss,
+            do_eta_cutoffs,
+            do_epsilon_cutoffs,
+            do_typical_ps,
+            do_quadratic,
+            do_xtc,
+            do_nsigmas,
+            do_dry,
+            do_skews,
+            do_temp_last)

     @classmethod
-    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
-                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
-                   temperature_lasts: List[bool], top_ps: List[float],
-                   top_ks: List[int], top_as: List[float],
-                   min_ps: List[float], presence_penalties: List[float],
-                   frequency_penalties: List[float],
-                   repetition_penalties: List[float],
-                   no_repeat_ngram_sizes: List[int], tfss: List[float],
-                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], smoothing_factors: List[float],
-                   smoothing_curves: List[float], xtc_thresholds: List[float],
-                   xtc_probabilities: List[float], nsigmas: List[float],
-                   dry_multipliers: List[float], dry_bases: List[float],
-                   dry_allowed_lengths: List[int],
-                   dry_sequence_breaker_ids: List[List[int]],
-                   dry_ranges: List[int], skews: List[float],
-                   sampling_seeds: List[List[int]],
-                   sample_indices: List[int], prompt_tokens: List[array],
-                   output_tokens: List[array], vocab_size: int,
-                   extra_seeds_to_generate: int, device: torch.device,
-                   dtype: torch.dtype) -> "SamplingTensors":
+    def from_lists(
+            cls,
+            temperatures: List[float],
+            dynatemp_mins: List[float],
+            dynatemp_maxs: List[float],
+            dynatemp_exps: List[float],
+            temperature_lasts: List[bool],
+            top_ps: List[float],
+            top_ks: List[int],
+            top_as: List[float],
+            min_ps: List[float],
+            presence_penalties: List[float],
+            frequency_penalties: List[float],
+            repetition_penalties: List[float],
+            no_repeat_ngram_sizes: List[int],
+            tfss: List[float],
+            eta_cutoffs: List[float],
+            epsilon_cutoffs: List[float],
+            typical_ps: List[float],
+            smoothing_factors: List[float],
+            smoothing_curves: List[float],
+            xtc_thresholds: List[float],
+            xtc_probabilities: List[float],
+            nsigmas: List[float],
+            dry_multipliers: List[float],
+            dry_bases: List[float],
+            dry_allowed_lengths: List[int],
+            dry_sequence_breaker_ids: List[List[int]],
+            dry_ranges: List[int],
+            skews: List[float],
+            prompt_tokens: List[array],
+            output_tokens: List[array],
+            vocab_size: int,
+            device: torch.device,
+            dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
         pin_memory = is_pin_memory_available()
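
The pinned-memory comment carries the key constraint: only page-locked host memory makes `non_blocking=True` copies truly asynchronous, so the CPU can keep building the next batch while tensors stream to the GPU. An illustrative stand-alone version of the pattern (not Aphrodite code):

```python
import torch

vals = [0.7, 1.0, 0.9]
host_t = torch.tensor(vals,
                      dtype=torch.float,
                      pin_memory=torch.cuda.is_available())
if torch.cuda.is_available():
    # Async with respect to the host because host_t is page-locked.
    dev_t = host_t.to(device="cuda", non_blocking=True)
```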
@@ -811,34 +810,9 @@ class SamplingTensors:
             pin_memory=pin_memory,
         )

-        sample_indices_t = torch.tensor(
-            sample_indices,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        )
-        # need to transpose and make contiguous to
-        # copy the tensor correctly.
-        # [batch_size, n_seeds] -> [n_seeds, batch_size]
-        sampling_seeds_t = torch.tensor(
-            sampling_seeds,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        ).t().contiguous()
-
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.

-        # How many seeds the sample operation itself will need.
-        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
-        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
-                                                 non_blocking=True)
-        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
-        if not extra_seeds_gpu.numel():
-            extra_seeds_gpu = None
-        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
-
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
@@ -882,38 +856,4 @@ class SamplingTensors:
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True),
             output_tokens=output_t.to(device=device, non_blocking=True),
-            sampling_seeds=sampling_seeds_gpu,
-            sample_indices=sample_indices_t.to(device=device,
-                                               non_blocking=True),
-            extra_seeds=extra_seeds_gpu,
         )
-
-    @staticmethod
-    def _get_sequence_seeds(
-        seed: int | None,
-        *extra_entropy: int,
-        seeds_to_generate: int,
-        is_greedy: bool,
-    ):
-        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
-        if not is_greedy:
-            if seed is None:
-                randint_fn = random.randint
-            else:
-                generator = random.Random(str((seed, ) + extra_entropy))
-                randint_fn = generator.randint
-            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
-            # If the user/random sets seed = 0 but request should
-            # have sampling, we need to change it to something
-            # else. We use a constant in that case.
-            # This way we don't need to create and load a bool
-            # matrix in the sampling kernel, which reduces CPU
-            # overhead and latency.
-            seq_seeds = [
-                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
-                for _ in range(seeds_to_generate)
-            ]
-        else:
-            # For the kernel, seed == 0 means greedy decoding.
-            seq_seeds = [0] * seeds_to_generate
-        return seq_seeds
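
For reference, the deleted `_get_sequence_seeds` leaned on Python's falsy `or` to keep drawn seeds away from 0, which the old Triton kernel reserved to mean greedy decoding. A standalone illustration of that trick, reusing the removed constant:

```python
import random

_SEED_0_REPLACEMENT = 3403598558  # constant from the removed code

# Derive a child seed deterministically from (seed,) + extra entropy;
# a drawn 0 is falsy, so `or` swaps in the replacement constant.
rng = random.Random(str((42,) + (128, 7)))
lo, hi = -(2**63), 2**63 - 1      # torch.iinfo(torch.long) range
child_seed = rng.randint(lo, hi) or _SEED_0_REPLACEMENT
assert child_seed != 0
```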