- """A layer that samples the next tokens from the model's outputs."""
- import itertools
- from typing import Dict, List, Optional, Tuple
- import torch
- import torch.nn as nn
- from aphrodite.common.sampling_params import SamplingType
- from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
- SamplerOutput, SequenceGroupOutput,
- SequenceOutput)
- from aphrodite.modeling.layers.ops.sample import sample as sample_triton
- from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
- SamplingTensors,
- SequenceGroupToSample)
- class Sampler(nn.Module):
- """Samples the next tokens from the model's outputs.
- This layer does the following:
- 1. Apply the min_tokens penalty and any custom token bans to the logits.
- 2. Apply all the other sampler functions (penalties, top-k/top-p/top-a/min-p,
- tail-free sampling, eta/epsilon cutoffs, typical sampling, quadratic smoothing
- and temperature) in the specified order.
- 3. Sample the next tokens.
- 4. Compute the logprobs of the sampled tokens and, if requested, the prompt
- logprobs.
- Here, each sequence group within the batch can have different sampling
- parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
- The structure of the logits tensor is coupled with the seq_groups in
- sampling_metadata. Typically, each sequence in each seq_group has one row in
- logits for the next token to be sampled; however, for a seq_group with a
- prompt request with the prompt_logprobs sampling parameter, there are rows
- in logits for each token in the input prompt.
- """
- def __init__(self):
- super().__init__()
- # Whether or not the SamplerOutput should have on-device tensors
- # containing the sampled token ids and probabilities. This is used by
- # speculative decoding.
- self.include_gpu_probs_tensor = False
- def forward(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> Optional[SamplerOutput]:
- """
- Args:
- logits: (num_tokens, vocab_size).
- sampling_metadata: Metadata for sampling.
- """
- assert logits is not None
- _, vocab_size = logits.shape
- # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
- # have not been generated yet
- logits = _apply_min_tokens_penalty(logits, sampling_metadata)
- # Prepare sampling tensors with pinned memory to avoid blocking.
- (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
- do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
- do_typical_ps,
- do_quadratic) = (SamplingTensors.from_sampling_metadata(
- sampling_metadata, vocab_size, logits.device, logits.dtype))
- if do_penalties:
- logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
- sampling_tensors.output_tokens,
- sampling_tensors.presence_penalties,
- sampling_tensors.frequency_penalties,
- sampling_tensors.repetition_penalties)
- if (do_topks or do_topps or do_topas or do_minps):
- logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
- sampling_tensors.top_ks,
- sampling_tensors.top_as,
- sampling_tensors.min_ps)
- if do_tfss:
- logits = _apply_tfs(logits, sampling_tensors.tfss)
- if do_eta_cutoffs:
- logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
- if do_epsilon_cutoffs:
- logits = _apply_epsilon_cutoff(logits,
- sampling_tensors.epsilon_cutoffs)
- if do_typical_ps:
- logits = _apply_typical_sampling(logits,
- sampling_tensors.typical_ps)
- if do_quadratic:
- logits = _apply_quadratic_sampling(
- logits, sampling_tensors.smoothing_factors,
- sampling_tensors.smoothing_curves)
- if do_temperatures:
- logits = _apply_temperature(logits, sampling_tensors.temperatures,
- sampling_tensors.dynatemp_mins,
- sampling_tensors.dynatemp_maxs,
- sampling_tensors.dynatemp_exps)
- banned_tokens = _get_custom_token_bans(sampling_metadata)
- # assert len(banned_tokens) == logits.shape[0]
- logits = _apply_token_bans(logits, banned_tokens)
- # We use float32 for probabilities and log probabilities.
- # Compute the probabilities.
- probs = torch.softmax(logits, dim=-1, dtype=torch.float)
- # Compute the log probabilities.
- logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
- # Sample the next tokens.
- sample_results, maybe_sampled_tokens_tensor = _sample(
- probs,
- logprobs,
- sampling_metadata,
- sampling_tensors,
- include_gpu_probs_tensor=self.include_gpu_probs_tensor,
- modify_greedy_probs=self._should_modify_greedy_probs_inplace,
- )
- if self.include_gpu_probs_tensor:
- assert maybe_sampled_tokens_tensor is not None
- sampled_tokens_tensor = maybe_sampled_tokens_tensor
- on_device_tensors = (probs, sampled_tokens_tensor)
- else:
- on_device_tensors = None
- # Get the logprobs query results.
- prompt_logprobs, sample_logprobs = _get_logprobs(
- logprobs, sampling_metadata, sample_results)
- return _build_sampler_output(sample_results,
- sampling_metadata,
- prompt_logprobs,
- sample_logprobs,
- on_device_tensors=on_device_tensors)
- @property
- def _should_modify_greedy_probs_inplace(self) -> bool:
- """Whether or not the sampler should modify the probability distribution
- of greedily-sampled tokens such that multinomial sampling would sample
- the greedily-sampled token.
- In other words, if True then we set the probability of the greedily-
- sampled token to 1.
- This is used by speculative decoding, which requires that the sampling
- method be encoded into the probability distribution.
- """
- # Modify greedy probs if include_gpu_probs_tensor is set.
- return self.include_gpu_probs_tensor
- def _get_bin_counts_and_mask(
- tokens: torch.Tensor,
- vocab_size: int,
- num_seqs: int,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # Compute the bin counts for the tokens.
- # vocab_size + 1 for padding.
- bin_counts = torch.zeros((num_seqs, vocab_size + 1),
- dtype=torch.long,
- device=tokens.device)
- bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
- bin_counts = bin_counts[:, :vocab_size]
- mask = bin_counts > 0
- return bin_counts, mask
- def _get_custom_token_bans(
- sampling_metadata: SamplingMetadata) -> List[List[int]]:
- assert sampling_metadata.seq_groups is not None
- banned_tokens: List[List[int]] = []
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- sampling_params = sampling_metadata.seq_groups[i].sampling_params
- seq_ids = seq_group.seq_ids
- custom_token_bans = sampling_params.custom_token_bans
- if (i < sampling_metadata.num_prompts
- and sampling_params.prompt_logprobs is not None):
- prompt_len = len(seq_group.prompt_logprob_indices)
- banned_tokens += [custom_token_bans] * (prompt_len - 1)
- banned_tokens += [custom_token_bans] * len(seq_ids)
- return banned_tokens
- def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
- output_tokens_tensor: torch.Tensor,
- presence_penalties: torch.Tensor,
- frequency_penalties: torch.Tensor,
- repetition_penalties: torch.Tensor) -> torch.Tensor:
- num_seqs, vocab_size = logits.shape
- _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
- num_seqs)
- output_bin_counts, output_mask = _get_bin_counts_and_mask(
- output_tokens_tensor, vocab_size, num_seqs)
- repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
- repetition_penalties[~(prompt_mask | output_mask)] = 1.0
- logits = torch.where(logits > 0, logits / repetition_penalties,
- logits * repetition_penalties)
- # We follow the definition in OpenAI API.
- # Refer to https://platform.openai.com/docs/api-reference/parameter-details
- logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
- logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
- return logits
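# Illustrative sketch (not part of the original module): a single-sequence call to
# _apply_penalties on toy tensors, showing the OpenAI-style presence/frequency
# penalties and the repetition penalty together. The _demo_* name and all values
# below are hypothetical and rely on the module-level torch import.
def _demo_penalties() -> None:
    logits = torch.tensor([[2.0, 1.0, 0.5, -0.5, 0.0, 3.0]])
    prompt_tokens = torch.tensor([[1]])
    output_tokens = torch.tensor([[0, 0, 5]])  # token 0 seen twice, token 5 once
    penalized = _apply_penalties(
        logits.clone(),
        prompt_tokens,
        output_tokens,
        presence_penalties=torch.tensor([0.5]),
        frequency_penalties=torch.tensor([0.2]),
        repetition_penalties=torch.tensor([1.3]),
    )
    # Tokens 0, 1 and 5 are pushed down; unseen tokens keep their logits.
    print(penalized)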
- def _apply_token_bans(logits: torch.Tensor,
- banned_tokens: List[List[int]]) -> torch.Tensor:
- for i, banned_token_ids in enumerate(banned_tokens):
- if not banned_token_ids:
- continue
- logits[i, banned_token_ids] = -float("inf")
- return logits
- def _apply_min_tokens_penalty(
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> torch.Tensor:
- """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
- have not been generated yet
- """
- # list of indices in logits that will be set to -inf
- logits_to_penalize = []
- logits_applied = 0
- for seq_group in sampling_metadata.seq_groups:
- seq_ids = seq_group.seq_ids
- sampling_params = seq_group.sampling_params
- sample_indices = seq_group.sample_indices
- logits_applied += len(sample_indices) + len(
- seq_group.prompt_logprob_indices)
- if not seq_group.do_sample:
- continue
- start_idx = sample_indices[0]
- min_tokens = sampling_params.min_tokens
- if min_tokens > 0:
- seqs_to_penalize = []
- for i, seq_id in enumerate(seq_ids):
- seq_data = seq_group.seq_data[seq_id]
- if len(seq_data.output_token_ids) < min_tokens:
- seqs_to_penalize.append(i)
- if seqs_to_penalize:
- # convert to the index into logits
- seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
- # use set() to remove any duplicates
- token_ids_to_penalize = set(sampling_params.stop_token_ids +
- [sampling_params.eos_token_id])
- # itertools.product pairs each seq index with every token id
- logits_to_penalize.extend(
- itertools.product(seqs_to_penalize, token_ids_to_penalize))
- if logits_to_penalize:
- # use zip and * to group indices along each dimension
- # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
- logits[tuple(zip(*logits_to_penalize))] = -float("inf")
- # verifies that no rows in logits were missed unexpectedly
- assert logits_applied == logits.shape[0]
- return logits
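# Illustrative sketch (not part of the original module): the zip(*...) trick used
# above turns a list of (row, column) pairs into the pair of index sequences that
# advanced indexing expects. The toy tensor and the _demo_* name are hypothetical.
def _demo_pair_indexing() -> None:
    logits = torch.zeros(3, 4)
    pairs = [(0, 1), (0, 3), (2, 2)]  # (sequence row, banned/stop token id)
    logits[tuple(zip(*pairs))] = -float("inf")
    # Row 0 now has -inf at columns 1 and 3, row 2 at column 2.
    print(logits)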
- def _apply_alphabet_soup(
- logits: torch.Tensor,
- p: torch.Tensor,
- k: torch.Tensor,
- a: torch.Tensor,
- m: torch.Tensor,
- ) -> torch.Tensor:
- logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
- # Apply top-p, min-p and top-a.
- probs_sort = logits_sort.softmax(dim=-1)
- probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
- min_p_thresholds = probs_sort[:, 0] * m
- top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
- threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
- mask = (probs_sort < threshold.unsqueeze(1)
- ) # Cull logits below the top-a threshold
- mask.logical_or_(
- probs_sum >
- p.unsqueeze(dim=1)) # Cull logits above the top-p summation threshold
- mask[:, 0] = False # Guarantee at least one token is pickable
- logits_sort[mask] = -float("inf")
- # Apply top-k.
- # Create a mask for the top-k elements.
- top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
- top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
- top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
- # Final mask.
- mask = (mask | top_k_mask)
- logits_sort.masked_fill_(mask, -float("inf"))
- # Re-sort the probabilities.
- src = torch.arange(logits_idx.shape[-1],
- device=logits_idx.device).expand_as(logits_idx)
- logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
- index=logits_idx,
- src=src)
- logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
- return logits
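# Illustrative sketch (not part of the original module): _apply_alphabet_soup on a
# single toy distribution with top-p=0.9, top-k=3, top-a disabled (0.0) and
# min-p=0.1. Only the three most likely tokens survive these settings. The tensors
# and the _demo_* name are hypothetical.
def _demo_alphabet_soup() -> None:
    # log of a proper probability vector, so softmax recovers it exactly
    logits = torch.log(torch.tensor([[0.4, 0.3, 0.15, 0.1, 0.05]]))
    filtered = _apply_alphabet_soup(
        logits.clone(),
        p=torch.tensor([0.9]),
        k=torch.tensor([3]),
        a=torch.tensor([0.0]),
        m=torch.tensor([0.1]),
    )
    print(filtered)  # the last two entries are now -inf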
- def _apply_tfs(
- logits: torch.Tensor,
- tfs: torch.Tensor,
- ) -> torch.Tensor:
- logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
- d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
- normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
- curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
- tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
- tfs_mask = torch.cat(
- (
- torch.zeros(
- logits.shape[0], 1, dtype=torch.bool, device=logits.device),
- tfs_mask,
- torch.ones(
- logits.shape[0], 1, dtype=torch.bool, device=logits.device),
- ),
- dim=-1,
- )
- logits_sort[tfs_mask] = -float("inf")
- logits = torch.gather(logits_sort,
- dim=-1,
- index=torch.argsort(logits_idx, dim=-1))
- return logits
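# Illustrative sketch (not part of the original module): tail-free sampling drops
# tokens once the CDF of the normalized absolute second differences ("curvature")
# of the sorted probabilities exceeds the tfs parameter. The toy distribution and
# the _demo_* name are hypothetical.
def _demo_tfs_curvature() -> None:
    probs = torch.tensor([[0.5, 0.25, 0.12, 0.08, 0.05]])  # already sorted
    d2 = probs.diff().diff().abs()
    curvature_cdf = (d2 / d2.sum(dim=-1, keepdim=True)).cumsum(dim=-1)
    # _apply_tfs masks positions where this CDF exceeds tfs, always keeping the
    # first sorted token and always dropping the last one.
    print(curvature_cdf)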
- def _apply_eta_cutoff(
- logits: torch.Tensor,
- eta_cutoff: torch.Tensor,
- ) -> torch.Tensor:
- shifted_logits = torch.log_softmax(logits, dim=-1)
- probs = shifted_logits.exp()
- neg_entropy = (probs * shifted_logits).nansum(dim=-1)
- eps = torch.min(eta_cutoff,
- torch.sqrt(eta_cutoff) *
- torch.exp(neg_entropy)).unsqueeze(dim=1)
- eta_mask = probs < eps
- # guard against nulling out all the logits
- top_idx = torch.argmax(probs, dim=1, keepdim=True)
- eta_mask.scatter_(dim=1, index=top_idx, value=False)
- logits[eta_mask] = -float("inf")
- return logits
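# Illustrative sketch (not part of the original module): eta sampling culls tokens
# whose probability falls below eps = min(eta, sqrt(eta) * exp(-H)), where H is
# the entropy of the distribution. The toy values and the _demo_* name are
# hypothetical.
def _demo_eta_threshold() -> None:
    probs = torch.tensor([[0.7, 0.2, 0.07, 0.03]])
    neg_entropy = (probs * probs.log()).sum(dim=-1)  # equals -H
    eta = torch.tensor([3e-3])
    eps = torch.min(eta, torch.sqrt(eta) * torch.exp(neg_entropy))
    print(eps, probs < eps.unsqueeze(dim=1))  # mask of tokens that would be culled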
- def _apply_epsilon_cutoff(
- logits: torch.Tensor,
- epsilon_cutoff: torch.Tensor,
- ) -> torch.Tensor:
- probs = logits.softmax(dim=-1)
- eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
- # guard against nulling out all the logits
- top_idx = torch.argmax(probs, dim=1, keepdim=True)
- eps_mask.scatter_(dim=1, index=top_idx, value=False)
- logits[eps_mask] = -float("inf")
- return logits
- def _apply_typical_sampling(
- logits: torch.Tensor,
- typical_p: torch.Tensor,
- ) -> torch.Tensor:
- shifted_logits = torch.log_softmax(logits, dim=-1)
- probs = shifted_logits.exp()
- neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
- surprisal_deviations = (neg_entropy - shifted_logits).abs()
- _, indices = torch.sort(surprisal_deviations)
- reordered_probs = probs.gather(-1, indices)
- typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
- dim=1)
- min_tokens_to_keep = 1
- # Keep at least min_tokens_to_keep
- typ_mask_sorted[..., :min_tokens_to_keep] = 0
- typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
- logits[typ_mask] = -float("inf")
- return logits
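# Illustrative sketch (not part of the original module): typical sampling ranks
# tokens by how far their surprisal (-log p) deviates from the distribution's
# entropy and keeps the most "typical" cumulative mass. The toy distribution and
# the _demo_* name are hypothetical.
def _demo_typicality() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    log_probs = probs.log()
    neg_entropy = (probs * log_probs).sum(dim=-1, keepdim=True)
    surprisal_deviations = (neg_entropy - log_probs).abs()
    # Smaller deviation == more typical; _apply_typical_sampling keeps tokens in
    # this order until typical_p of the probability mass is covered.
    print(surprisal_deviations)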
- # pulls double duty for temperature and dynatemp
- def _apply_temperature(
- logits: torch.Tensor,
- temperatures: torch.Tensor,
- dynatemp_mins: torch.Tensor,
- dynatemp_maxs: torch.Tensor,
- dynatemp_exps: torch.Tensor,
- ) -> torch.Tensor:
- dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
- # Check if dynatemp_mask is not empty
- if dynatemp_mask.any():
- dynatemp_mins = dynatemp_mins[dynatemp_mask]
- dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
- dynatemp_exps = dynatemp_exps[dynatemp_mask]
- dynatemp_mins = dynatemp_mins.clamp_(min=0)
- dynatemp_logits = logits[dynatemp_mask]
- dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
- dynatemp_probs = dynatemp_shifted_logits.exp()
- dynatemp_entropies = -(dynatemp_probs *
- dynatemp_shifted_logits).nansum(dim=-1)
- dynatemp_max_entropies = torch.log_(
- (dynatemp_logits > float("-inf")).sum(dim=-1).float())
- normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
- dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
- normalized_entropies.pow_(dynatemp_exps))
- temperatures[dynatemp_mask] = dyn_temp
- temperatures[temperatures == 0.0] = 1.0
- while temperatures.dim() < logits.dim():
- temperatures = temperatures.unsqueeze(-1)
- logits = logits.div(temperatures)
- return logits
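# Illustrative sketch (not part of the original module): dynamic temperature
# interpolates between dynatemp_min and dynatemp_max using the normalized entropy
# of the row raised to dynatemp_exp, mirroring the formula above. The toy values
# and the _demo_* name are hypothetical.
def _demo_dynatemp() -> None:
    probs = torch.tensor([[0.85, 0.1, 0.04, 0.01]])
    entropy = -(probs * probs.log()).sum(dim=-1)
    max_entropy = torch.log(torch.tensor(float(probs.shape[-1])))
    normalized = entropy / max_entropy
    t_min, t_max, exponent = 0.5, 1.5, 1.0
    dyn_temp = t_min + (t_max - t_min) * normalized.pow(exponent)
    print(dyn_temp)  # lower-entropy (more peaked) rows land closer to t_min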
- def _apply_quadratic_sampling(
- logits: torch.Tensor,
- smoothing_factor: torch.Tensor,
- smoothing_curve: torch.Tensor,
- ) -> torch.Tensor:
- """
- Applies a quadratic transformation to the logits based on the
- provided smoothing factors and curves. The transformation is
- centered around the maximum logit value in the batch.
- The transformation involves a quadratic and cubic term, with the
- cubic term controlled by the smoothing curve. The quadratic term is
- scaled by the smoothing factor, and the cubic term is scaled by the
- product of the smoothing factor and the smoothing curve.
- params:
- logits (torch.Tensor): The logits to be transformed.
- smoothing_factor (torch.Tensor): The factor to scale the quadratic
- term in the transformation.
- smoothing_curve (torch.Tensor): The factor to scale the cubic term
- in the transformation.
- returns:
- torch.Tensor: The transformed logits.
- Credits: @kalomaze
- """
- max_logits = logits.max(dim=-1, keepdim=True).values
- diff = logits - max_logits
- smoothing_factor.unsqueeze_(dim=1)
- smoothing_curve.unsqueeze_(dim=1)
- k = (3 - smoothing_curve) / 2
- s = (smoothing_curve - 1) / 2
- mask = smoothing_factor > 0
- mask = mask.flatten()
- transformed_logits = torch.where(
- logits != float('-inf'), -(k * smoothing_factor * diff**2) +
- (s * smoothing_factor * diff**3) + max_logits, logits)
- logits[mask, :] = transformed_logits[mask, :]
- return logits
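# Illustrative sketch (not part of the original module): the quadratic/cubic
# transform above written out for a single row, with k and s derived from the
# smoothing curve the same way as in _apply_quadratic_sampling. Toy values; the
# _demo_* name is hypothetical.
def _demo_quadratic_transform() -> None:
    logits = torch.tensor([[3.0, 1.0, 0.0, -2.0]])
    smoothing_factor, smoothing_curve = 0.3, 1.0
    k = (3 - smoothing_curve) / 2   # weight of the quadratic term
    s = (smoothing_curve - 1) / 2   # weight of the cubic term
    max_logits = logits.max(dim=-1, keepdim=True).values
    diff = logits - max_logits
    transformed = (-(k * smoothing_factor * diff**2)
                   + (s * smoothing_factor * diff**3) + max_logits)
    print(transformed)  # logits further from the max are pulled down more strongly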
- def _greedy_sample(
- selected_seq_groups: List[SequenceGroupToSample],
- samples: torch.Tensor,
- ) -> List[Tuple[List[int], List[int]]]:
- """Run greedy sampling on a given samples.
- Args:
- selected_seq_groups: A list of sequence groups batched.
- samples: (num_selected_samples,) A tensor of samples. The length of
- samples could be smaller than selected_seq_groups if
- seq_group.do_sample is False.
- Returns:
- A list of (next_token_ids, parent_ids) tuples. The length of the returned
- list is the same as the length of selected_seq_groups. If the corresponding
- seq_group has do_sample=False, the tuple contains ([], []).
- """
- samples = samples.tolist()
- sample_idx = 0
- results = []
- for seq_group in selected_seq_groups:
- if not seq_group.do_sample:
- results.append(([], []))
- continue
- seq_ids = seq_group.seq_ids
- num_parent_seqs = len(seq_ids)
- assert num_parent_seqs == 1, (
- "Greedy sampling should have only one seq.")
- parent_ids = list(range(num_parent_seqs))
- next_token_ids = [samples[sample_idx]]
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- return results
- def _random_sample(
- selected_seq_groups: List[SequenceGroupToSample],
- random_samples: torch.Tensor,
- ) -> List[Tuple[List[int], List[int]]]:
- """Run random sampling on a given samples.
- Args:
- selected_seq_groups: A list of sequence groups batched.
- random_samples: (num_selected_samples,) A tensor of samples. The
- length of samples could be smaller than selected_seq_groups if
- seq_group.do_sample is False.
- Returns:
- A list of (next_token_ids, parent_ids) tuples. The length of the returned
- list is the same as the length of selected_seq_groups. If the corresponding
- seq_group has do_sample=False, the tuple contains ([], []).
- """
- # Copy the sampled tokens to the CPU before converting them to Python lists.
- random_samples = random_samples.cpu()
- sample_idx = 0
- results = []
- for seq_group in selected_seq_groups:
- if not seq_group.do_sample:
- results.append(([], []))
- continue
- seq_ids = seq_group.seq_ids
- sampling_params = seq_group.sampling_params
- is_prompt = seq_group.is_prompt
- num_parent_seqs = len(seq_ids)
- if is_prompt:
- # Prompt phase.
- parent_ids = [0] * sampling_params.best_of
- next_token_ids = random_samples[
- sample_idx, :sampling_params.best_of].tolist()
- else:
- # Generation phase.
- parent_ids = list(range(num_parent_seqs))
- next_token_ids = random_samples[sample_idx:sample_idx +
- num_parent_seqs, 0].tolist()
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- return results
- def _beam_search_sample(
- selected_seq_groups: List[SequenceGroupToSample],
- logprobs: torch.Tensor,
- ) -> List[Tuple[List[int], List[int]]]:
- """Run beam sampling on a given samples.
- Args:
- selected_seq_groups: A list of sequence groups batched.
- logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
- on selected sample indices.
- Returns:
- A list of (next_token_ids, parent_ids) tuples. The length of the returned
- list is the same as the length of selected_seq_groups. If the corresponding
- seq_group has do_sample=False, the tuple contains ([], []).
- """
- # We sample 2 * beam_width candidates to make sure that with high
- # probability we can get `beam_width` candidates in addition to
- # the finished sequences for the next iteration. See
- # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
- # for details. See also HF reference:
- # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
- #
- # NOTE: Beam search is not vectorized, so its speed can be slower than
- # other sampling methods.
- sample_idx = 0
- results = []
- for seq_group in selected_seq_groups:
- if not seq_group.do_sample:
- results.append(([], []))
- continue
- is_prompt = seq_group.is_prompt
- seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
- num_parent_seqs = len(seq_ids)
- beam_width = sampling_params.best_of
- seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
- if is_prompt:
- # Prompt phase.
- assert num_parent_seqs == 1, (
- "Prompt input should have only one seq.")
- parent_ids = [0] * (2 * beam_width)
- _, next_token_ids = torch.topk(seq_group_logprobs[0],
- 2 * beam_width)
- next_token_ids = next_token_ids.tolist()
- else:
- # Generation phase.
- cumulative_logprobs = [
- seq_group.seq_data[seq_id].cumulative_logprob
- for seq_id in seq_ids
- ]
- cumulative_logprobs = torch.tensor(
- cumulative_logprobs,
- dtype=torch.float,
- device=seq_group_logprobs.device)
- seq_group_logprobs = (seq_group_logprobs +
- cumulative_logprobs.unsqueeze(dim=1))
- _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
- 2 * beam_width)
- topk_ids = topk_ids.tolist()
- vocab_size = seq_group_logprobs.size(-1)
- parent_ids = [i // vocab_size for i in topk_ids]
- next_token_ids = [i % vocab_size for i in topk_ids]
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- assert sample_idx == logprobs.size(0)
- return results
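# Illustrative sketch (not part of the original module): during the generation
# phase, beam candidates come from a top-k over the flattened
# (num_beams x vocab_size) score matrix; the parent beam and the token id are then
# recovered with integer division and modulo. The toy scores and the _demo_* name
# are hypothetical.
def _demo_flattened_topk() -> None:
    vocab_size = 5
    scores = torch.tensor([[0.1, 0.9, 0.2, 0.0, 0.3],   # beam 0
                           [0.8, 0.1, 0.7, 0.2, 0.0]])  # beam 1
    _, topk_ids = torch.topk(scores.flatten(), 2 * 2)   # 2 * beam_width
    topk_ids = topk_ids.tolist()
    parent_ids = [i // vocab_size for i in topk_ids]
    next_token_ids = [i % vocab_size for i in topk_ids]
    print(parent_ids, next_token_ids)  # [0, 1, 1, 0] and [1, 0, 2, 4]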
- # torch.multinomial forces a GPU<->CPU sync.
- # Therefore, we use an optimized implementation instead.
- # Note that we always sample with replacement.
- # probs will be modified in place, but this is fine, as we pass
- # in a copy already.
- def _multinomial(
- probs: torch.Tensor,
- num_samples: int,
- seq_groups: Optional[List[SequenceGroupToSample]] = None,
- ) -> torch.Tensor:
- if num_samples > 1:
- # This is equivalent to torch.repeat_interleave (which also
- # forces a GPU<->CPU sync).
- # This allows us to do sampling with replacement by creating
- # num_samples copies of each row in the tensor, and then
- # batch sampling the resulting tensor.
- probs = probs[:, None, :].expand(probs.shape[0], num_samples,
- probs.shape[1]).contiguous().view(
- -1, probs.shape[1])
- q = torch.empty_like(probs)
- if seq_groups is None:
- q.exponential_()
- else:
- sample_idx = 0
- for seq_group in seq_groups:
- seq_ids = seq_group.seq_ids
- next_sample_idx = sample_idx + len(seq_ids) * num_samples
- q[sample_idx:next_sample_idx].exponential_(
- generator=seq_group.generator)
- sample_idx = next_sample_idx
- return probs.div_(q).argmax(dim=1).view(-1, num_samples)
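# Illustrative sketch (not part of the original module): dividing the probabilities
# by i.i.d. Exponential(1) noise and taking the argmax draws an index with
# probability proportional to probs (the Gumbel-max trick in another form), which
# is the property _multinomial relies on. The _demo_* name and the draw count are
# hypothetical.
def _demo_exponential_trick() -> None:
    torch.manual_seed(0)
    probs = torch.tensor([0.6, 0.3, 0.1])
    draws = torch.stack([
        (probs / torch.empty_like(probs).exponential_()).argmax()
        for _ in range(10000)
    ])
    # The empirical frequencies should be close to [0.6, 0.3, 0.1].
    print(torch.bincount(draws, minlength=3).float() / draws.numel())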
- def _sample_with_torch(
- probs: torch.Tensor,
- logprobs: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- include_gpu_probs_tensor: bool,
- modify_greedy_probs: bool,
- ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
- categorized_seq_group_ids = {t: [] for t in SamplingType}
- categorized_sample_indices = sampling_metadata.categorized_sample_indices
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- sampling_params = seq_group.sampling_params
- sampling_type = sampling_params.sampling_type
- categorized_seq_group_ids[sampling_type].append(i)
- sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
- sample_metadata = {}
- multinomial_samples = {}
- # Create output tensor for sampled token ids.
- if include_gpu_probs_tensor:
- sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
- 1,
- dtype=torch.long,
- device=logprobs.device)
- else:
- sampled_token_ids_tensor = None
- # Counterintuitively, having two loops here is actually faster.
- # The first loop can run without waiting on GPU<->CPU sync.
- for sampling_type in SamplingType:
- sample_indices = categorized_sample_indices[sampling_type][:, 0]
- num_tokens = len(sample_indices)
- if num_tokens == 0:
- continue
- seq_group_id = categorized_seq_group_ids[sampling_type]
- seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
- sample_metadata[sampling_type] = (seq_group_id, seq_groups)
- long_sample_indices = sample_indices.long()
- if sampling_type == SamplingType.GREEDY:
- greedy_samples = torch.argmax(logprobs[long_sample_indices],
- dim=-1)
- if include_gpu_probs_tensor:
- # Store sampled tokens in output tensor.
- sampled_token_ids_tensor[
- long_sample_indices] = greedy_samples.unsqueeze(-1)
- if modify_greedy_probs:
- # If required, modify the probabilities such that sampling from
- # the modified distribution would always sample the argmax
- # token id.
- _modify_greedy_probs_inplace(logprobs, probs,
- long_sample_indices,
- greedy_samples)
- elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
- max_best_of_in_batch = 1
- for seq_group in seq_groups:
- if seq_group.is_prompt:
- sampling_params = seq_group.sampling_params
- max_best_of_in_batch = max(max_best_of_in_batch,
- sampling_params.best_of)
- seeded_args = {} if sampling_type == SamplingType.RANDOM else {
- "seq_groups": seq_groups,
- }
- multinomial_samples[sampling_type] = _multinomial(
- probs[long_sample_indices], max_best_of_in_batch,
- **seeded_args)
- if include_gpu_probs_tensor:
- # Store sampled tokens in output tensor.
- sampled_token_ids_tensor[
- long_sample_indices] = multinomial_samples[sampling_type]
- elif sampling_type == SamplingType.BEAM:
- beam_search_logprobs = logprobs[sample_indices]
- else:
- raise ValueError(f"Unsupported sampling type: {sampling_type}")
- # GPU<->CPU sync happens in the loop below.
- # This also converts the sample output to Python objects.
- for sampling_type in SamplingType:
- if sampling_type not in sample_metadata:
- continue
- (seq_group_id, seq_groups) = sample_metadata[sampling_type]
- if sampling_type == SamplingType.GREEDY:
- sample_results = _greedy_sample(seq_groups, greedy_samples)
- elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
- sample_results = _random_sample(seq_groups,
- multinomial_samples[sampling_type])
- elif sampling_type == SamplingType.BEAM:
- sample_results = _beam_search_sample(seq_groups,
- beam_search_logprobs)
- sample_results_dict.update(zip(seq_group_id, sample_results))
- sample_results = [
- sample_results_dict.get(i, ([], []))
- for i in range(len(sampling_metadata.seq_groups))
- ]
- return sample_results, sampled_token_ids_tensor
- def _sample_with_triton_kernel(
- probs: torch.Tensor,
- logprobs: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- sampling_tensors: SamplingTensors,
- ) -> List[Tuple[List[int], List[int]]]:
- categorized_seq_group_ids = {t: [] for t in SamplingType}
- categorized_sample_indices = sampling_metadata.categorized_sample_indices
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- sampling_params = seq_group.sampling_params
- sampling_type = sampling_params.sampling_type
- categorized_seq_group_ids[sampling_type].append(i)
- sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
- sample_metadata = {}
- max_best_of_in_batch = 1
- # Counterintuitively, having two loops here is actually faster.
- # The first loop can run without waiting on GPU<->CPU sync.
- for sampling_type in SamplingType:
- sample_indices = categorized_sample_indices[sampling_type][:, 0]
- sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
- num_tokens = len(sample_indices)
- if num_tokens == 0:
- continue
- seq_group_id = categorized_seq_group_ids[sampling_type]
- seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
- sample_metadata[sampling_type] = (seq_group_id, seq_groups,
- sample_indices,
- sampled_token_indices)
- if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
- SamplingType.RANDOM_SEED):
- for seq_group in seq_groups:
- if seq_group.is_prompt:
- sampling_params = seq_group.sampling_params
- max_best_of_in_batch = max(max_best_of_in_batch,
- sampling_params.best_of)
- elif sampling_type == SamplingType.BEAM:
- beam_search_logprobs = logprobs[sample_indices]
- else:
- raise ValueError(f"Unsupported sampling type: {sampling_type}")
- sampled_tokens, _, _ = sample_triton(
- probs=probs,
- seeds=sampling_tensors.sampling_seeds,
- max_best_of=max_best_of_in_batch,
- sample_indices=sampling_tensors.sample_indices,
- logprobs=logprobs,
- # don't save logprobs because we have logic for that below
- # TODO: use this instead of the CPU-based logic below
- save_logprobs=False,
- )
- # GPU<->CPU sync happens in the loop below.
- for sampling_type in SamplingType:
- if sampling_type not in sample_metadata:
- continue
- (seq_group_id, seq_groups, sample_indices,
- sampled_token_indices) = sample_metadata[sampling_type]
- if sampling_type == SamplingType.GREEDY:
- sample_results = _greedy_sample(
- seq_groups, sampled_tokens[sampled_token_indices][:, 0])
- elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
- sample_results = _random_sample(
- seq_groups, sampled_tokens[sampled_token_indices])
- elif sampling_type == SamplingType.BEAM:
- sample_results = _beam_search_sample(seq_groups,
- beam_search_logprobs)
- sample_results_dict.update(zip(seq_group_id, sample_results))
- sample_results = [
- sample_results_dict.get(i, ([], []))
- for i in range(len(sampling_metadata.seq_groups))
- ]
- return sample_results
- def _sample(
- probs: torch.Tensor, logprobs: torch.Tensor,
- sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
- include_gpu_probs_tensor: bool, modify_greedy_probs: bool
- ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
- """
- Args:
- probs: (num_query_tokens_in_batch, num_vocab)
- logprobs: (num_query_tokens_in_batch, num_vocab)
- sampling_metadata: The metadata for a batch for sampling.
- sampling_tensors: Tensors that include sampling related metadata.
- Returns:
- (next_token_ids, parent_seq_ids) for each seq group in a batch.
- If sampling is skipped, it returns ([], [])
- sampled_token_ids_tensor: A tensor of sampled token ids.
- """
- return _sample_with_torch(
- probs,
- logprobs,
- sampling_metadata,
- include_gpu_probs_tensor=include_gpu_probs_tensor,
- modify_greedy_probs=modify_greedy_probs,
- )
- # TODO: Enable once Triton kernel & associated code is faster.
- # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
- # sampling_tensors)
- def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
- """
- This function calculates the ranks of the chosen tokens in a logprob tensor.
- Args:
- x (torch.Tensor): 2D logprob tensor of shape (N, M)
- where N is the no. of tokens and M is the vocab dim.
- indices (torch.Tensor): 1D tensor of chosen token indices.
- Returns:
- torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
- Each element in the returned tensor represents the rank
- of the chosen token in the input logprob tensor.
- """
- vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
- indices]
- return (x > vals[:, None]).long().sum(1).add_(1)
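# Illustrative sketch (not part of the original module): the rank of a chosen token
# is 1 plus the number of vocabulary entries with a strictly larger logprob. The
# toy tensors and the _demo_* name are hypothetical.
def _demo_get_ranks() -> None:
    logprobs = torch.log(torch.tensor([[0.5, 0.3, 0.2],
                                       [0.1, 0.2, 0.7]]))
    chosen = torch.tensor([1, 2])  # token 1 in row 0, token 2 in row 1
    print(_get_ranks(logprobs, chosen))  # tensor([2, 1])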
- def _get_logprobs(
- logprobs: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- sample_results: List[Tuple[List[int], List[int]]],
- ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
- """Return sample lobprobs and prompt logprobs.
- The logic consists of 3 parts.
- - Select indices to compute logprob from, ranks of token ids, and
- the top k token ids from logprobs.
- - Compute prompt logprobs if required.
- - Compute sample logprobs if required.
- Args:
- logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
- logprob per vocab. Sequence groups' query tokens are batched in a
- single flattened tensor. For example, assuming there are N
- seq groups, it is sorted by prefill tokens for seq_group_1 (if
- prompt logprob is enabled), decode tokens for seq_group_1 (if
- sampling is required), prefill tokens for seq_group_2, ...
- sampling_metadata: The sampling metadata.
- sample_results: (num_seq_groups) The tuple of (next_token_ids,
- parent_ids) for each sequence group. When beam search is enabled,
- sample_results can contain a different number of seq_ids than
- sampling_metadata.seq_groups. This is because beam search creates
- 2 * BEAM_WIDTH samples (whereas there are only up to BEAM_WIDTH
- seq_ids).
- Returns:
- A tuple of prompt and sample logprobs per sequence group in a batch.
- """
- # The indices of the query tokens to calculate logprobs for. This includes
- # both prompt and sample logprob indices.
- query_indices: List[int] = []
- # The next token ids to get the logprob value from.
- next_token_ids: List[int] = []
- # The largest requested number of logprobs. The top-k logprobs are computed
- # once for the whole batch using this value as k.
- largest_num_logprobs = 1
- # Select indices to compute logprob from, ranks of token ids, and the top
- # k token ids from logprobs.
- for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
- sample_results):
- sampling_params = seq_group.sampling_params
- # Update indices and tokens for prompt logprobs.
- if (seq_group.is_prompt
- and sampling_params.prompt_logprobs is not None):
- largest_num_logprobs = max(largest_num_logprobs,
- sampling_params.prompt_logprobs)
- next_prompt_tokens = _get_next_prompt_tokens(seq_group)
- query_indices.extend(seq_group.prompt_logprob_indices)
- next_token_ids.extend(next_prompt_tokens)
- # Update indices and next tokens for sample logprobs.
- if seq_group.do_sample:
- token_ids, parent_seq_ids = sample_result
- # NOTE: We cannot directly use sample_indices because
- # sample_indices only contain parent seq_ids of a previous step.
- # The current step may have different number of seq_ids, and
- # we can obtain it from `sample_result[1]`.
- query_idx = seq_group.sample_indices[0]
- query_indices.extend(
- [query_idx + parent_id for parent_id in parent_seq_ids])
- next_token_ids.extend(token_ids)
- if sampling_params.logprobs is not None:
- largest_num_logprobs = max(largest_num_logprobs,
- sampling_params.logprobs)
- assert len(next_token_ids) == len(query_indices)
- if len(query_indices) == 0:
- empty_sampled_logprob = []
- empty_prompt_logprob = None
- return [empty_prompt_logprob], [empty_sampled_logprob]
- query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
- next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
- # (num_selected_query_tokens, num_logprobs). Note that query_indices can
- # contain duplicates if beam search is enabled.
- selected_logprobs = logprobs[[
- query_indices_gpu,
- next_token_ids_gpu,
- ]]
- ranks = _get_ranks(
- logprobs[query_indices_gpu],
- next_token_ids_gpu,
- )
- assert selected_logprobs.shape[0] == ranks.shape[0]
- # Logprobs of topk tokens for a batch of sequence groups.
- # (num_query_tokens_across_batch).
- if largest_num_logprobs > 0:
- top_logprobs, top_token_ids = torch.topk(logprobs,
- largest_num_logprobs,
- dim=-1)
- top_logprobs = top_logprobs.cpu()
- top_token_ids = top_token_ids.cpu()
- else:
- top_logprobs, top_token_ids = None, None
- selected_logprobs = selected_logprobs.cpu()
- ranks = ranks.cpu()
- # Find prompt/sample logprobs.
- prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
- sample_logprobs_per_seq_group: List[SampleLogprobs] = []
- top_logprob_idx = 0
- selected_logprobs_idx = 0
- for seq_group, sample_result in zip(sampling_metadata.seq_groups,
- sample_results):
- (prompt_logprobs, top_logprob_idx,
- selected_logprobs_idx) = _get_prompt_logprob_if_needed(
- seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
- selected_logprobs_idx, top_logprob_idx)
- prompt_logprobs_per_seq_group.append(prompt_logprobs)
- (sampled_logprobs, top_logprob_idx,
- selected_logprobs_idx) = _get_sampled_logprob_if_needed(
- seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
- top_logprobs, selected_logprobs_idx, top_logprob_idx)
- sample_logprobs_per_seq_group.append(sampled_logprobs)
- return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
- def _get_prompt_logprob_if_needed(
- seq_group: SequenceGroupToSample,
- selected_logprobs: torch.Tensor,
- ranks: torch.Tensor,
- top_token_ids: torch.Tensor,
- top_logprobs: torch.Tensor,
- selected_logprobs_idx: int,
- top_logprob_idx: int,
- ):
- """Compute the prompt logprob from a sequence group if needed."""
- sampling_params = seq_group.sampling_params
- is_prompt = seq_group.is_prompt
- # Find prompt logprobs
- prompt_logprobs: Optional[PromptLogprobs] = None
- if (is_prompt and sampling_params.prompt_logprobs is not None):
- prompt_logprobs = []
- num_logprobs = sampling_params.prompt_logprobs
- next_prompt_tokens = _get_next_prompt_tokens(seq_group)
- for token_id in next_prompt_tokens:
- # Calculate the prompt logprob of the real prompt tokens.
- # Use tuple here for performance (to use tolist()).
- # {token_id: (logprob, rank_from_vocab)}
- prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
- token_id: (selected_logprobs[selected_logprobs_idx].item(),
- ranks[selected_logprobs_idx].item())
- }
- # Add top K prompt logprobs along with its rank.
- if num_logprobs > 0:
- prompt_logprobs_dict.update(
- zip(
- top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
- zip(
- top_logprobs[
- top_logprob_idx, :num_logprobs].tolist(),
- # These are the ranks. Since top_logprobs is sorted,
- # we can just use a range here.
- range(1, num_logprobs + 1))))
- prompt_logprobs.append({
- token_id: Logprob(*logprob_and_rank)
- for token_id, logprob_and_rank in prompt_logprobs_dict.items()
- })
- # + 1 to go to the next prompt token.
- top_logprob_idx += 1
- selected_logprobs_idx += 1
- return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
- def _get_sampled_logprob_if_needed(
- seq_group: SequenceGroupToSample,
- sample_result: Tuple[List[int], List[int]],
- selected_logprobs: torch.Tensor,
- ranks: torch.Tensor,
- top_token_ids: torch.Tensor,
- top_logprobs: torch.Tensor,
- selected_logprobs_idx: int,
- top_logprob_idx: int,
- ):
- """Compute the sample logprob if needed."""
- seq_ids = seq_group.seq_ids
- num_logprobs = seq_group.sampling_params.logprobs
- if num_logprobs is None:
- num_logprobs = 0
- sampled_logprobs: SampleLogprobs = []
- next_token_ids, parent_seq_ids = sample_result
- if seq_group.do_sample:
- assert len(next_token_ids) > 0
- for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
- # Calculate the sample logprob of the real sampled tokens.
- # Use tuple here for performance (to use tolist()).
- # token_id: (logprob, rank_from_vocab)
- sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
- next_token_id:
- (selected_logprobs[selected_logprobs_idx].item(),
- ranks[selected_logprobs_idx].item())
- }
- # +1 to go to the next sampled token. Note that
- # selected_logprobs can contain duplicates unlike top_logprobs
- # when beam search is enabled.
- selected_logprobs_idx += 1
- # Second, add top K logprobs along with its rank.
- if num_logprobs >= 0:
- sampled_logprobs_dict.update(
- zip(
- top_token_ids[top_logprob_idx +
- parent_id, :num_logprobs].tolist(),
- zip(
- top_logprobs[top_logprob_idx +
- parent_id, :num_logprobs].tolist(),
- # These are the ranks. Since top_logprobs is sorted, we
- # can just use a range here.
- range(1, num_logprobs + 1))))
- sampled_logprobs.append({
- token_id: Logprob(*logprob_and_rank)
- for token_id, logprob_and_rank in
- sampled_logprobs_dict.items()
- })
- # There are len(seq_ids) number of sampled tokens for the current
- # sequence group in top_logprobs. Jump to the next seq_group.
- top_logprob_idx += len(seq_ids)
- return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
- def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
- sample_indices: torch.Tensor,
- greedy_samples: torch.Tensor) -> None:
- """Modify the probability distributions of the greedily-sampled tokens such
- that each sampled token has a "probability" of 1.0. This is required by
- speculative decoding, which depends on the sampling method being encoded
- within the probability distribution for correctness.
- # Why do we only need to do this for greedy sampling?
- Aphrodite's sampler performs the following steps for greedy or multinomial
- (random) sampling:
- 1. Get logits from model.
- 2. Modify logits according to per-sequence sampling parameters.
- - Divide by temperature, apply top-k and top-p masking, penalize tokens
- according to their frequency, etc.
- 3. Sample a token.
- - Random sampling simply samples from the modified probability
- distribution.
- - Greedy sampling performs `argmax` to obtain the token with the
- highest likelihood.
-
- Ignoring greedy sampling for a moment, the computed probability distribution
- has the following property: sampling from it independently reproduces each
- token with the same frequency with which the Sampler itself would produce it.
- In other words, for tokens sampled with Aphrodite's random SamplingType, the
- computed probability distribution encodes the sampling methodology completely.
- Greedy sampling does not normally have this property. Aphrodite modifies
- logits according to sampling params, then performs `argmax`, then returns
- the sampled token and the computed probability distribution. If we sample
- from the distribution, we'll find the likelihood of the greedily-sampled
- token is not always 1.0.
- Since lossless speculative decoding requires that the sampling methodology
- be encoded within the probability distribution, we are motivated to modify
- the probability distribution such that the sampled token has probability 1
- when speculative decoding is used.
- NOTE: Alternatively, we could use an extremely low temperature to achieve
- greedy sampling using multinomial computation and unite the codepaths. This
- has implications on the overall design of the sampler, e.g. how to record
- accurate logprobs for the user, so this improvement is deferred to later.
- """
- logprobs[sample_indices, :] = -float('inf')
- logprobs[sample_indices, greedy_samples] = 0.0
- probs[sample_indices, :] = 0
- probs[sample_indices, greedy_samples] = 1.0
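# Illustrative sketch (not part of the original module): after the in-place edit
# above, the rows selected for greedy sampling become one-hot on the argmax token,
# so a later multinomial draw from them is guaranteed to reproduce the greedy
# choice. The toy tensors and the _demo_* name are hypothetical.
def _demo_greedy_probs_one_hot() -> None:
    probs = torch.tensor([[0.6, 0.3, 0.1],
                          [0.2, 0.5, 0.3]])
    logprobs = probs.log()
    sample_indices = torch.tensor([0, 1])
    greedy_samples = probs.argmax(dim=-1)
    _modify_greedy_probs_inplace(logprobs, probs, sample_indices, greedy_samples)
    print(probs)  # rows are now one-hot at columns 0 and 1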
- def _build_sampler_output(
- sample_results: List[Tuple[List[int], List[int]]],
- sampling_metadata: SamplingMetadata,
- prompt_logprobs: List[Optional[PromptLogprobs]],
- sample_logprobs: List[SampleLogprobs],
- on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
- ) -> SamplerOutput:
- """Construct Python objects with the output of sampling.
- Args:
- on_device_tensors: Tuple containing on-device tensors with the
- probabilities used in sampling and the sampled token ids. This
- allows post-processing without copies to CPU/serialization, e.g. in
- speculative decoding rejection sampling.
- """
- sampler_output = []
- for (seq_group, sample_result, group_prompt_logprobs,
- group_sample_logprobs) in zip(sampling_metadata.seq_groups,
- sample_results, prompt_logprobs,
- sample_logprobs):
- seq_ids = seq_group.seq_ids
- next_token_ids, parent_ids = sample_result
- seq_outputs = []
- for parent_id, next_token_id, logprobs in zip(parent_ids,
- next_token_ids,
- group_sample_logprobs):
- seq_outputs.append(
- SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
- sampler_output.append(
- SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
- # If not specified, store None values in SamplerOutput.
- if on_device_tensors is not None:
- sampled_token_probs, sampled_token_ids = on_device_tensors
- else:
- sampled_token_probs, sampled_token_ids = (None, None)
- return SamplerOutput(
- outputs=sampler_output,
- sampled_token_probs=sampled_token_probs,
- sampled_token_ids=sampled_token_ids,
- )
- def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
- """Get a list of next prompt tokens to compute logprob from a
- given sequence group.
- It is used to compute prompt logprobs. Each query token has a logprob, and it
- needs to know the id of the prompt token that follows it in order to compute
- its prompt logprob. This helper obtains those next prompt token ids.
- This API must only be used when the caller knows the seq_group is in the
- prefill stage.
- Returns:
- A list of next prompt tokens to compute logprob.
- """
- assert seq_group.is_prompt, (
- "Caller should ensure the sequence group is in a prefill stage.")
- seq_ids = seq_group.seq_ids
- subquery_len = seq_group.subquery_len
- assert subquery_len is not None
- # prompt has only 1 seq id.
- assert len(seq_ids) == 1
- seq_data = seq_group.seq_data[seq_ids[0]]
- computed_len = seq_data.get_num_computed_tokens()
- prompt_tokens = seq_data.prompt_token_ids
- # +1 because we are looking for a next prompt token.
- next_token_index_start = computed_len + 1
- next_token_index_end = min(computed_len + subquery_len + 1,
- len(prompt_tokens))
- next_prompt_tokens = prompt_tokens[
- next_token_index_start:next_token_index_end]
- return next_prompt_tokens
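# Illustrative sketch (not part of the original module): the "+1" shift above pairs
# each query position with the prompt token that follows it, which is the token
# whose prompt logprob is being computed. A plain-Python toy example; all names and
# values are hypothetical.
def _demo_next_prompt_token_window() -> None:
    prompt_tokens = [11, 22, 33, 44, 55]
    computed_len, subquery_len = 1, 3  # this chunk covers prompt positions 1..3
    start = computed_len + 1
    end = min(computed_len + subquery_len + 1, len(prompt_tokens))
    print(prompt_tokens[start:end])  # [33, 44, 55]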
- # def _apply_mirostat_v2(logits: torch.Tensor,
- # sampling_tensors: SamplingTensors) -> torch.Tensor:
- # # Reduce our view to just the affected logits
- # logit_view = logits[sampling_tensors.miro_indices]
- # # Calculate surprise value per token
- # # Convert nats to bits for compatibility with ooba/kobold parameters.
- # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
- # # Mask out "too-surprising" tokens (surprisal > mu)
- # mus = sampling_tensors.miro_mus
- # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
- # # Unmask most-likely logit to guarantee a selection.
- # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
- # miro_mask.scatter_(dim=1, index=maxinds, value=False)
- # # Apply logit mask (effectively a top-k filter).
- # logit_view[miro_mask] = -float("inf")
- # # Project logit changes made to the view onto the original.
- # # I think this step might be redundant.
- # logits[sampling_tensors.miro_indices] = logit_view
- # return logits
- # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
- # sample_results: List[Tuple[List[int], List[int]]],
- # sampling_metadata: SamplingMetadata,
- # output_metadata: OutputMetadata) -> None:
- # """Based on whichever token was finally sampled, we calculate the
- # final surprisal values to update the mus.
- # Because a single sequence can have multiple samples, we must fork
- # the mu accordingly."""
- # assert sampling_metadata.seq_groups is not None
- # seqid_to_tokens = {}
- # seqid_to_indices = {}
- # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
- # sample_results):
- # for idx, (token, parent) in enumerate(zip(toks, parents)):
- # seqid_to_tokens.setdefault(sids[parent], []).append(token)
- # seqid_to_indices.setdefault(sids[parent], []).append(idx)
- # seqids = args.miro_seqids
- # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
- # device=logits.device,
- # dtype=torch.long)
- # # Clumsily, we recalculate token surprisals.
- # logits_view = logits[args.miro_indices]
- # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
- # dim=-1,
- # index=picked_tokens) / -math.log(2)
- # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
- # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
- # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
- # nu_mus = mus - (picked_surprise - taus) * etas
- # # Record updated mu values for use in the next iteration
- # # Note how each mu is split into multiple based on the number of samples.
- # for seqid, seq_mus in zip(seqids, nu_mus):
- # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
- # output_metadata.add(seqid, sample_idx, "miro_mu", mu)