- """A layer that samples the next tokens from the model's outputs."""
- from typing import Dict, List, Tuple, Optional
- import torch
- import torch.nn as nn
- from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
- OutputMetadata,
- SamplingTensors)
- from aphrodite.modeling.megatron.communication_op import (
- tensor_model_parallel_gather)
- from aphrodite.common.sampling_params import SamplingParams, SamplingType
- from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
- SamplerOutput, SequenceData,
- SequenceGroupOutput, SequenceOutput)
- class Sampler(nn.Module):
- """Samples the next tokens from the model's outputs.
- This layer does the following:
- 1. Discard the hidden states that are not used for sampling (i.e., all
- tokens except the final one in each prompt).
- 2. Compute the logits for the next tokens.
- 3. Apply presence, frequency and repetition penalties.
- 4. Apply temperature scaling (including dynamic temperature).
- 5. Apply top-p, top-k and the other truncation/cutoff methods.
- 6. Sample the next tokens.
- Here, each sequence group within the batch can have different sampling
- parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
- """
- def __init__(self,
- vocab_size: int,
- org_vocab_size: Optional[int] = None) -> None:
- super().__init__()
- self.vocab_size = vocab_size
- # original vocabulary size (without LoRA).
- self.org_vocab_size = org_vocab_size or vocab_size
- def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
- embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
- # Get the logits for the next tokens.
- logits = torch.matmul(hidden_states, embedding.t())
- if embedding_bias is not None:
- logits += embedding_bias
- logits = tensor_model_parallel_gather(logits)
- # Remove paddings in vocab (if any).
- if logits is not None:
- logits = logits[:, :self.org_vocab_size]
- return logits
- def forward(
- self,
- embedding: torch.Tensor,
- hidden_states: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- embedding_bias: Optional[torch.Tensor] = None,
- ) -> Optional[SamplerOutput]:
- # Get the hidden states that we use for sampling.
- hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
- # Get the logits for the next tokens.
- logits = self._get_logits(hidden_states, embedding, embedding_bias)
- # Only perform sampling in the driver worker.
- # Note: `_get_logits` is still distributed across TP workers because
- # the `embedding` weight is distributed across TP workers.
- # TODO: Change the get_logits part to a separate stage.
- if not sampling_metadata.perform_sampling:
- return None
- assert logits is not None
- _, vocab_size = logits.shape
- output_metadata = OutputMetadata()
- # Apply logits processors (if any)
- logits = _apply_logits_processors(logits, sampling_metadata)
- # Prepare sampling tensors with pinned memory to avoid blocking.
- (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
- do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
- do_typical_ps, do_quadratic,
- do_mirostat) = (SamplingTensors.from_sampling_metadata(
- sampling_metadata, vocab_size, logits.device, logits.dtype))
- if do_penalties:
- logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
- sampling_tensors.output_tokens,
- sampling_tensors.presence_penalties,
- sampling_tensors.frequency_penalties,
- sampling_tensors.repetition_penalties)
- if do_temperatures:
- logits = _apply_temperature(logits, sampling_tensors.temperatures,
- sampling_tensors.dynatemp_mins,
- sampling_tensors.dynatemp_maxs,
- sampling_tensors.dynatemp_exps)
- if do_topks or do_topps or do_topas or do_minps:
- logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
- sampling_tensors.top_ks,
- sampling_tensors.top_as,
- sampling_tensors.min_ps)
- if do_tfss:
- logits = _apply_tfs(logits, sampling_tensors.tfss)
- if do_eta_cutoffs:
- logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
- if do_epsilon_cutoffs:
- logits = _apply_epsilon_cutoff(logits,
- sampling_tensors.epsilon_cutoffs)
- if do_typical_ps:
- logits = _apply_typical_sampling(logits,
- sampling_tensors.typical_ps,
- sampling_tensors.typical_p_sigmas)
- if do_quadratic:
- logits = _apply_quadratic_sampling(
- logits, sampling_tensors.smoothing_factors,
- sampling_tensors.smoothing_curves)
- banned_tokens = _get_custom_token_bans(sampling_metadata)
- assert len(banned_tokens) == logits.shape[0]
- logits = _apply_token_bans(logits, banned_tokens)
- if do_mirostat:
- logits = _mirostat(logits, sampling_tensors, output_metadata)
- # We use float32 for probabilities and log probabilities.
- # Compute the probabilities.
- probs = torch.softmax(logits, dim=-1, dtype=torch.float)
- # Compute the log probabilities.
- # Use log_softmax to ensure numerical stability.
- logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
- # Sample the next tokens.
- sample_results = _sample(probs, logprobs, sampling_metadata)
- # Get the logprobs query results.
- prompt_logprobs, sample_logprobs = _get_logprobs(
- logprobs, sampling_metadata, sample_results)
- return _build_sampler_output(sample_results, sampling_metadata,
- prompt_logprobs, sample_logprobs,
- output_metadata)
- # FIXME: This is a hack for the missing GPU blocks. This should be removed
- # once a proper fix is implemented.
- class QuantSampler(nn.Module):
- """Samples the next tokens from the model's outputs.
- This layer does the following:
- 1. Discard the hidden states that are not used for sampling (i.e., all
- tokens except the final one in each prompt).
- 2. Compute the logits for the next tokens.
- 3. Apply presence, frequency and repetition penalties.
- 4. Apply temperature scaling (including dynamic temperature).
- 5. Apply top-p, top-k and the other truncation/cutoff methods.
- 6. Sample the next tokens.
- Here, each sequence group within the batch can have different sampling
- parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
- """
- def __init__(self,
- vocab_size: int,
- org_vocab_size: Optional[int] = None) -> None:
- super().__init__()
- self.vocab_size = vocab_size
- # original vocabulary size (without LoRA).
- self.org_vocab_size = org_vocab_size or vocab_size
- def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
- embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
- # Get the logits for the next tokens.
- logits = torch.matmul(hidden_states, embedding.t())
- if embedding_bias is not None:
- logits += embedding_bias
- logits = tensor_model_parallel_gather(logits)
- # Remove paddings in vocab (if any).
- if logits is not None:
- logits = logits[:, :self.org_vocab_size]
- return logits
- def forward(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> Optional[SamplerOutput]:
- # Prune the logits to the token positions selected for sampling.
- logits = _prune_hidden_states(logits, sampling_metadata)
- logits = tensor_model_parallel_gather(logits)
- # Remove paddings in vocab (if any).
- if logits is not None:
- logits = logits[:, :self.vocab_size]
- # Only perform sampling in the driver worker.
- # Note: the gather above is needed because the logits are still sharded
- # across tensor-parallel workers at this point.
- # TODO: Change the logits computation to a separate stage.
- if not sampling_metadata.perform_sampling:
- return None
- assert logits is not None
- _, vocab_size = logits.shape
- output_metadata = OutputMetadata()
- # Apply logits processors (if any)
- logits = _apply_logits_processors(logits, sampling_metadata)
- # Prepare sampling tensors with pinned memory to avoid blocking.
- (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
- do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
- do_typical_ps, do_quadratic,
- do_mirostat) = (SamplingTensors.from_sampling_metadata(
- sampling_metadata, vocab_size, logits.device, logits.dtype))
- if do_penalties:
- logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
- sampling_tensors.output_tokens,
- sampling_tensors.presence_penalties,
- sampling_tensors.frequency_penalties,
- sampling_tensors.repetition_penalties)
- if do_temperatures:
- logits = _apply_temperature(logits, sampling_tensors.temperatures,
- sampling_tensors.dynatemp_mins,
- sampling_tensors.dynatemp_maxs,
- sampling_tensors.dynatemp_exps)
- if do_topks or do_topps or do_topas or do_minps:
- logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
- sampling_tensors.top_ks,
- sampling_tensors.top_as,
- sampling_tensors.min_ps)
- if do_tfss:
- logits = _apply_tfs(logits, sampling_tensors.tfss)
- if do_eta_cutoffs:
- logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
- if do_epsilon_cutoffs:
- logits = _apply_epsilon_cutoff(logits,
- sampling_tensors.epsilon_cutoffs)
- if do_typical_ps:
- logits = _apply_typical_sampling(logits,
- sampling_tensors.typical_ps,
- sampling_tensors.typical_p_sigmas)
- if do_quadratic:
- logits = _apply_quadratic_sampling(
- logits, sampling_tensors.smoothing_factors,
- sampling_tensors.smoothing_curves)
- banned_tokens = _get_custom_token_bans(sampling_metadata)
- assert len(banned_tokens) == logits.shape[0]
- logits = _apply_token_bans(logits, banned_tokens)
- if do_mirostat:
- logits = _mirostat(logits, sampling_tensors, output_metadata)
- # We use float32 for probabilities and log probabilities.
- # Compute the probabilities.
- probs = torch.softmax(logits, dim=-1, dtype=torch.float)
- # Compute the log probabilities.
- # Use log_softmax to ensure numerical stability.
- logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
- # Sample the next tokens.
- sample_results = _sample(probs, logprobs, sampling_metadata)
- # Get the logprobs query results.
- prompt_logprobs, sample_logprobs = _get_logprobs(
- logprobs, sampling_metadata, sample_results)
- return _build_sampler_output(sample_results, sampling_metadata,
- prompt_logprobs, sample_logprobs,
- output_metadata)
- def _prune_hidden_states(
- hidden_states: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> torch.Tensor:
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
- return hidden_states.index_select(0,
- sampling_metadata.selected_token_indices)
- def _get_bin_counts_and_mask(
- tokens: torch.Tensor,
- vocab_size: int,
- num_seqs: int,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # Compute the bin counts for the tokens.
- # vocab_size + 1 for padding.
- bin_counts = torch.zeros((num_seqs, vocab_size + 1),
- dtype=torch.long,
- device=tokens.device)
- bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
- bin_counts = bin_counts[:, :vocab_size]
- mask = bin_counts > 0
- return bin_counts, mask
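- # Illustrative sketch (hypothetical toy inputs, not part of the module):
- # for a single sequence whose token ids are [2, 2, 5] and vocab_size=8,
- # scatter_add_ counts occurrences per vocab id, and the extra
- # (vocab_size + 1)-th bin absorbs any padding id equal to vocab_size:
- #   >>> tokens = torch.tensor([[2, 2, 5]])
- #   >>> counts, mask = _get_bin_counts_and_mask(tokens, 8, 1)
- #   >>> counts.tolist()
- #   [[0, 0, 2, 0, 0, 1, 0, 0]]
- #   >>> mask[0, 2].item(), mask[0, 3].item()
- #   (True, False)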
- def _get_custom_token_bans(
- sampling_metadata: SamplingMetadata) -> List[List[int]]:
- banned_tokens: List[List[int]] = []
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- seq_ids, sampling_params = seq_group
- custom_token_bans = sampling_params.custom_token_bans
- if (i < sampling_metadata.num_prompts
- and sampling_params.prompt_logprobs is not None):
- prompt_len = sampling_metadata.prompt_lens[i]
- banned_tokens += [custom_token_bans] * (prompt_len - 1)
- banned_tokens += [custom_token_bans] * len(seq_ids)
- return banned_tokens
- # def _apply_logits_processors(
- # logits: torch.Tensor,
- # metadata: SamplingMetadata,
- # ) -> torch.Tensor:
- # seq_offset = 0
- # for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
- # seq_size = len(seq_ids)
- # output_tokens = []
- # if (i < metadata.num_prompts
- # and sampling_params.prompt_logprobs is not None):
- # prompt_seqs = metadata.prompt_lens[i] - 1
- # seq_size += prompt_seqs
- # output_tokens.extend([[]] * prompt_seqs)
- # seq_end = seq_offset + seq_size
- # if sampling_params.logits_processors:
- # output_tokens.extend(metadata.seq_data[sid].output_token_ids
- # for sid in seq_ids)
- # for proc in sampling_params.logits_processors:
- # proc(logits[seq_offset:seq_end], output_tokens)
- # seq_offset = seq_end
- # return logits
- def _apply_logits_processors(
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> torch.Tensor:
- logits_row_idx = 0
- found_logits_processors = False
- for seq_ids, sampling_params in sampling_metadata.seq_groups:
- logits_processors = sampling_params.logits_processors
- if logits_processors:
- found_logits_processors = True
- for seq_id in seq_ids:
- logits_row = logits[logits_row_idx]
- token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
- for logits_processor in logits_processors:
- logits_row = logits_processor(token_ids, logits_row)
- logits[logits_row_idx] = logits_row
- logits_row_idx += 1
- else:
- logits_row_idx += len(seq_ids)
- if found_logits_processors:
- assert logits_row_idx == logits.shape[0]
- return logits
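- # Illustrative example of a compatible logits processor (hypothetical, not
- # part of the module): a callable taking (output_token_ids, logits_row) and
- # returning the adjusted row, matching the call order used above.
- #
- #   def ban_token_after_five(token_ids: List[int],
- #                            row: torch.Tensor) -> torch.Tensor:
- #       # Hypothetical rule: forbid token id 13 once 5 tokens were generated.
- #       if len(token_ids) >= 5:
- #           row[13] = -float("inf")
- #       return row
- #
- # Such callables are supplied via SamplingParams.logits_processors.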
- def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
- output_tokens_tensor: torch.Tensor,
- presence_penalties: torch.Tensor,
- frequency_penalties: torch.Tensor,
- repetition_penalties: torch.Tensor) -> torch.Tensor:
- num_seqs, vocab_size = logits.shape
- _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
- num_seqs)
- output_bin_counts, output_mask = _get_bin_counts_and_mask(
- output_tokens_tensor, vocab_size, num_seqs)
- repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
- repetition_penalties[~(prompt_mask | output_mask)] = 1.0
- logits = torch.where(logits > 0, logits / repetition_penalties,
- logits * repetition_penalties)
- # We follow the definition in OpenAI API.
- # Refer to https://platform.openai.com/docs/api-reference/parameter-details
- logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
- logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
- return logits
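- # Worked numbers (illustrative): with repetition_penalty=1.2, a positive
- # logit of a previously seen token is divided by 1.2 while a negative one
- # is multiplied by 1.2 (both push it down).  On top of that, a token that
- # appeared 3 times in the output with frequency_penalty=0.2 and
- # presence_penalty=0.5 loses a further 0.2 * 3 + 0.5 = 1.1 from its logit,
- # matching the OpenAI parameter semantics referenced above.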
- def _apply_token_bans(logits: torch.Tensor,
- banned_tokens: List[List[int]]) -> torch.Tensor:
- for i, banned_token_ids in enumerate(banned_tokens):
- if not banned_token_ids:
- continue
- logits[i, banned_token_ids] = -float("inf")
- return logits
- def _apply_alphabet_soup(
- logits: torch.Tensor,
- p: torch.Tensor,
- k: torch.Tensor,
- a: torch.Tensor,
- m: torch.Tensor,
- ) -> torch.Tensor:
- logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
- # Apply top-p, min-p and top-a.
- probs_sort = logits_sort.softmax(dim=-1)
- probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
- min_p_thresholds = probs_sort[:, 0] * m
- top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
- threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
- mask = (probs_sort < threshold.unsqueeze(1)
- ) # Cull logits below the top-a threshold
- mask.logical_or_(
- probs_sum >
- p.unsqueeze(dim=1)) # Cull logits above the top-p summation threshold
- mask[:, 0] = False # Guarantee at least one token is pickable
- logits_sort[mask] = -float("inf")
- # Apply top-k.
- # Create a mask for the top-k elements.
- top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
- top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
- top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
- # Final mask.
- mask = (mask | top_k_mask)
- logits_sort.masked_fill_(mask, -float("inf"))
- # Re-sort the probabilities.
- src = torch.arange(logits_idx.shape[-1],
- device=logits_idx.device).expand_as(logits_idx)
- logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
- index=logits_idx,
- src=src)
- logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
- return logits
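- # Worked sketch (illustrative values): for one row with sorted probs
- # [0.5, 0.3, 0.1, 0.06, 0.04] and p=0.8, k=2, a=0.0, m=0.0:
- #   * probs_sum is the *exclusive* cumulative sum [0.0, 0.5, 0.8, 0.9, 0.96],
- #     so top-p culls the last two entries (0.9 > 0.8 and 0.96 > 0.8);
- #   * top-k keeps only the first two sorted entries;
- #   * the surviving logits are then scattered back to their original
- #     positions via the inverse of logits_idx.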
- def _apply_tfs(
- logits: torch.Tensor,
- tfs: torch.Tensor,
- ) -> torch.Tensor:
- logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
- d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
- normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
- curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
- tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
- tfs_mask = torch.cat(
- (
- torch.zeros(
- logits.shape[0], 1, dtype=torch.bool, device=logits.device),
- tfs_mask,
- torch.ones(
- logits.shape[0], 1, dtype=torch.bool, device=logits.device),
- ),
- dim=-1,
- )
- logits_sort[tfs_mask] = -float("inf")
- logits = torch.gather(logits_sort,
- dim=-1,
- index=torch.argsort(logits_idx, dim=-1))
- return logits
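- # Descriptive note on the masking above: |d2| is the second finite
- # difference of the sorted probability curve, i.e. its curvature.  Tokens
- # are culled once the normalized cumulative curvature exceeds `tfs`, and
- # the two concatenated columns pin the endpoints: the most likely token is
- # always kept (False) and the least likely is always dropped (True).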
- def _apply_eta_cutoff(
- logits: torch.Tensor,
- eta_cutoff: torch.Tensor,
- ) -> torch.Tensor:
- # `eta_cutoff` is given in units of 1e-4 and is already a tensor.
- eta = eta_cutoff.to(dtype=logits.dtype, device=logits.device) * 1e-4
- shifted_logits = torch.log_softmax(logits, dim=-1)
- probs = shifted_logits.exp()
- neg_entropy = (probs * shifted_logits).nansum(dim=-1)
- eps = torch.min(eta,
- torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)
- eta_mask = probs < eps
- if torch.all(eta_mask): # guard against nulling out all the logits
- topk_prob, _ = torch.max(probs, dim=-1)
- eta_mask = probs < topk_prob
- logits[eta_mask] = -float("inf")
- return logits
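- # Illustrative numbers (assuming eta_cutoff=3, i.e. eta = 3e-4): for a
- # peaked distribution exp(-H) is close to 1, so the cutoff saturates at
- # eta itself; for a high-entropy distribution exp(-H) is tiny, the
- # effective cutoff sqrt(eta) * exp(-H) drops far below eta, and few or no
- # tokens are culled.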
- def _apply_epsilon_cutoff(
- logits: torch.Tensor,
- epsilon_cutoff: torch.Tensor,
- ) -> torch.Tensor:
- # `epsilon_cutoff` is given in units of 1e-4 and is already a tensor.
- eps = epsilon_cutoff.to(dtype=logits.dtype,
- device=logits.device).unsqueeze(dim=1)
- probs = logits.softmax(dim=-1)
- eps_mask = probs < (eps * 1e-4)
- if torch.all(eps_mask): # guard against nulling out all the logits
- topk_prob, _ = torch.max(probs, dim=-1)
- eps_mask = probs < topk_prob
- logits[eps_mask] = -float("inf")
- return logits
- def _apply_typical_sampling(
- logits: torch.Tensor,
- typical_p: torch.Tensor,
- typical_p_sigma: torch.Tensor,
- ) -> torch.Tensor:
- typ_p = typical_p.clone().detach().to(logits.device).to(logits.dtype)
- typ_threshold = typical_p_sigma.clone().detach().to(logits.device).to(
- logits.dtype)
- THRESHOLD = 1000
- shifted_logits = torch.log_softmax(logits, dim=-1)
- probs = torch.exp(shifted_logits)
- neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
- # NOTE: We don't take the absolute value of the surprisal deviations.
- # This deviates from the original implementation.
- surprisal_deviations = neg_entropy - shifted_logits
- _, indices = torch.sort(surprisal_deviations)
- reordered_probs = probs.gather(-1, indices)
- typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
- # Invert the minimum deviation from the expected information content of the
- # probability distribution for the next token and scale it based on the
- # provided typical_p_sigma parameter.
- max_threshold = surprisal_deviations.min().negative() * typ_threshold
- # Mask negative deviations and positive deviations above the max threshold
- max_threshold = max_threshold.unsqueeze(1)
- surprisal_deviations[surprisal_deviations <= 0] = THRESHOLD
- surprisal_deviations[surprisal_deviations > max_threshold] = THRESHOLD
- positive_mask = surprisal_deviations == THRESHOLD
- typ_mask_sorted[..., :1] = 0
- typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
- # Merge the mask created above with the one from standard typical-p sampling.
- # Masked-out tokens in the distribution are True.
- typ_mask = typ_mask.bitwise_and(positive_mask)
- logits[typ_mask] = -float("inf")
- return logits
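- # Descriptive note: `surprisal_deviations` above is the signed difference
- # between each token's surprisal (-log p) and the distribution's entropy H,
- # since neg_entropy - shifted_logits == (-log p) - H.  Tokens are ordered by
- # that deviation and kept until their cumulative probability reaches
- # typical_p; the sigma-scaled threshold additionally bounds how far above
- # the expected information content a kept token may lie.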
- # pulls double duty for temperature and dynatemp
- def _apply_temperature(
- logits: torch.Tensor,
- temperatures: torch.Tensor,
- dynatemp_mins: torch.Tensor,
- dynatemp_maxs: torch.Tensor,
- dynatemp_exps: torch.Tensor,
- ) -> torch.Tensor:
- dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
- dynatemp_mins = dynatemp_mins[dynatemp_mask]
- dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
- dynatemp_exps = dynatemp_exps[dynatemp_mask]
- dynatemp_mins = dynatemp_mins.clamp_(min=0)
- dynatemp_logits = logits[dynatemp_mask]
- dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
- dynatemp_probs = dynatemp_shifted_logits.exp()
- dynatemp_entropies = -(dynatemp_probs *
- dynatemp_shifted_logits).nansum(dim=-1)
- dynatemp_max_entropies = torch.log_(
- (dynatemp_logits > float("-inf")).sum(dim=-1).float())
- normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
- dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
- normalized_entropies.pow_(dynatemp_exps))
- temperatures[dynatemp_mask] = dyn_temp
- temperatures[temperatures == 0.0] = 1.0
- logits.div_(temperatures.unsqueeze_(dim=1))
- return logits
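- # Worked sketch (illustrative values): with dynatemp_min=0.5,
- # dynatemp_max=2.0, dynatemp_exp=1.0 and a row whose entropy is half the
- # maximum possible entropy (normalized_entropy = 0.5), the effective
- # temperature is 0.5 + (2.0 - 0.5) * 0.5 = 1.25.  Rows without dynatemp
- # fall through to the plain `temperatures` scaling, and zero temperatures
- # are mapped to 1.0 since greedy sampling is unaffected by the scale.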
- def _apply_quadratic_sampling(
- logits: torch.Tensor,
- smoothing_factors: torch.Tensor,
- smoothing_curves: torch.Tensor,
- ) -> torch.Tensor:
- """
- Applies a quadratic transformation to the logits based on the
- provided smoothing factors and curves. The transformation is
- centered around the maximum logit value in the batch.
- The transformation involves a quadratic and cubic term, with the
- cubic term controlled by the smoothing curve. The quadratic term is
- scaled by the smoothing factor, and the cubic term is scaled by the
- product of the smoothing factor and the smoothing curve.
- Args:
- logits (torch.Tensor): The logits to be transformed.
- smoothing_factors (torch.Tensor): The factors to scale the quadratic
- term in the transformation.
- smoothing_curves (torch.Tensor): The factors to scale the cubic term
- in the transformation.
- Returns:
- torch.Tensor: The transformed logits.
- Credits: @kalomaze
- """
- max_logits = logits.max(dim=-1, keepdim=True).values
- diff = logits - max_logits
- smoothing_factors.unsqueeze_(dim=1)
- smoothing_curves.unsqueeze_(dim=1)
- k = (3 - smoothing_curves) / 2
- s = (smoothing_curves - 1) / 2
- mask = smoothing_factors > 0
- mask = mask.flatten()
- transformed_logits = torch.where(
- logits != float('-inf'), -(k * smoothing_factors * diff**2) +
- (s * smoothing_factors * diff**3) + max_logits, logits)
- logits[mask, :] = transformed_logits[mask, :]
- return logits
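- # Worked sketch (illustrative values): with smoothing_factor=0.3 and
- # smoothing_curve=1.0 the cubic term vanishes (k=1, s=0), so a logit
- # sitting 2.0 below the row maximum is mapped to 0.3 * 2.0**2 = 1.2 below
- # it, flattening the distribution around the top token.  Rows with
- # smoothing_factor <= 0 are left untouched by the final mask.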
- def _greedy_sample(
- selected_seq_groups: List[Tuple[List[int], SamplingParams]],
- samples: torch.Tensor,
- ) -> List[Tuple[List[int], List[int]]]:
- samples = samples.tolist()
- sample_idx = 0
- results = []
- for seq_group in selected_seq_groups:
- seq_ids, _ = seq_group
- num_parent_seqs = len(seq_ids)
- assert num_parent_seqs == 1, (
- "Greedy sampling should have only one seq.")
- parent_ids = list(range(num_parent_seqs))
- next_token_ids = [samples[sample_idx]]
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- return results
- def _random_sample(
- selected_seq_groups: List[Tuple[List[int], SamplingParams]],
- is_prompts: List[bool],
- random_samples: torch.Tensor,
- ) -> List[Tuple[List[int], List[int]]]:
- # Move the samples to CPU once to avoid repeated GPU<->CPU syncs below.
- random_samples = random_samples.cpu()
- sample_idx = 0
- results = []
- for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
- seq_ids, sampling_params = seq_group
- num_parent_seqs = len(seq_ids)
- if is_prompt:
- # Prompt phase.
- parent_ids = [0] * sampling_params.best_of
- next_token_ids = random_samples[
- sample_idx, :sampling_params.best_of].tolist()
- else:
- # Generation phase.
- parent_ids = list(range(num_parent_seqs))
- next_token_ids = random_samples[sample_idx:sample_idx +
- num_parent_seqs, 0].tolist()
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- return results
- def _beam_search_sample(
- selected_seq_groups: List[Tuple[List[int], SamplingParams]],
- is_prompts: List[bool],
- seq_data: Dict[int, SequenceData],
- logprobs: torch.Tensor,
- ) -> List[Tuple[List[int], List[int]]]:
- # We sample 2 * beam_width candidates to make sure that with high
- # probability we can get `beam_width` candidates in addition to
- # the finished sequences for the next iteration. See
- # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
- # for details. See also HF reference:
- # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
- #
- # Note: Beam search is not vectorized, so its speed can be slower than
- # other sampling methods.
- sample_idx = 0
- results = []
- for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
- seq_ids, sampling_params = seq_group
- num_parent_seqs = len(seq_ids)
- beam_width = sampling_params.best_of
- seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
- if is_prompt:
- # Prompt phase.
- assert num_parent_seqs == 1, (
- "Prompt input should have only one seq.")
- parent_ids = [0] * (2 * beam_width)
- _, next_token_ids = torch.topk(seq_group_logprobs[0],
- 2 * beam_width)
- next_token_ids = next_token_ids.tolist()
- else:
- # Generation phase.
- cumulative_logprobs = [
- seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
- ]
- cumulative_logprobs = torch.tensor(
- cumulative_logprobs,
- dtype=torch.float,
- device=seq_group_logprobs.device)
- seq_group_logprobs = (seq_group_logprobs +
- cumulative_logprobs.unsqueeze(dim=1))
- _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
- 2 * beam_width)
- topk_ids = topk_ids.tolist()
- vocab_size = seq_group_logprobs.size(-1)
- parent_ids = [i // vocab_size for i in topk_ids]
- next_token_ids = [i % vocab_size for i in topk_ids]
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- assert sample_idx == logprobs.size(0)
- return results
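- # Illustrative decode of the flattened top-k above (hypothetical numbers):
- # with vocab_size=4 and num_parent_seqs=2, a flattened index of 6 maps to
- # parent_id = 6 // 4 = 1 and next_token_id = 6 % 4 = 2, i.e. token 2
- # continued from the second parent sequence.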
- # torch.multinomial forces a GPU<->CPU sync.
- # Therefore, we use an optimized implementation instead.
- # Note that we always sample with replacement.
- # probs will be modified in place, but this is fine, as we pass
- # in a copy already.
- def _multinomial(
- probs: torch.Tensor,
- num_samples: int,
- seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
- generators: Optional[List[torch.Generator]] = None,
- ) -> torch.Tensor:
- if num_samples > 1:
- # This is equivalent to torch.repeat_interleaved (which also
- # forces a GPU<->CPU sync).
- # This allows us to do sampling with replacement by creating
- # num_samples copies of each row in the tensor, and then
- # batch sampling the resulting tensor.
- probs = probs[:, None, :].expand(probs.shape[0], num_samples,
- probs.shape[1]).contiguous().view(
- -1, probs.shape[1])
- q = torch.empty_like(probs)
- if seq_groups is None:
- q.exponential_()
- else:
- sample_idx = 0
- for (seq_ids, _), generator in zip(seq_groups, generators):
- next_sample_idx = sample_idx + len(seq_ids) * num_samples
- q[sample_idx:next_sample_idx].exponential_(generator=generator)
- sample_idx = next_sample_idx
- return probs.div_(q).argmax(dim=1).view(-1, num_samples)
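- # Why the exponential trick above works (sketch): with q ~ Exponential(1)
- # i.i.d., argmax(p / q) == argmin(q / p), q_i / p_i is Exponential with
- # rate p_i, and the minimum of independent exponentials lands on index i
- # with probability p_i / sum_j p_j -- i.e. a categorical sample without the
- # GPU<->CPU sync of torch.multinomial.
- #
- # Quick empirical check (illustrative, not part of the module):
- #   >>> probs = torch.tensor([[0.7, 0.2, 0.1]]).repeat(100000, 1)
- #   >>> samples = _multinomial(probs.clone(), num_samples=1)
- #   >>> torch.bincount(samples.flatten()).float() / samples.numel()
- #   # approximately tensor([0.7, 0.2, 0.1])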
- def _sample(
- probs: torch.Tensor,
- logprobs: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> List[Tuple[List[int], List[int]]]:
- categorized_seq_group_ids = {t: [] for t in SamplingType}
- categorized_sample_indices = sampling_metadata.categorized_sample_indices
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- _, sampling_params = seq_group
- sampling_type = sampling_params.sampling_type
- categorized_seq_group_ids[sampling_type].append(i)
- sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
- sample_metadata = {}
- multinomial_samples = {}
- # Counterintuitively, having two loops here is actually faster.
- # The first loop can run without waiting on GPU<->CPU sync.
- for sampling_type, sample_indices in categorized_sample_indices.items():
- if len(sample_indices) == 0:
- continue
- seq_group_ids = categorized_seq_group_ids[sampling_type]
- seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
- is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
- sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
- is_prompts, sample_indices)
- if sampling_type == SamplingType.GREEDY:
- greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
- elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
- max_best_of = 1
- for seq_group, is_prompt in zip(seq_groups, is_prompts):
- if is_prompt:
- _, sampling_params = seq_group
- max_best_of = max(max_best_of, sampling_params.best_of)
- seeded_args = {} if sampling_type == SamplingType.RANDOM else {
- "seq_groups": seq_groups,
- "generators": sampling_metadata.generators,
- }
- multinomial_samples[sampling_type] = _multinomial(
- probs[sample_indices], max_best_of, **seeded_args)
- elif sampling_type == SamplingType.BEAM:
- beam_search_logprobs = logprobs[sample_indices]
- else:
- raise ValueError(f"Unsupported sampling type: {sampling_type}")
- # GPU<->CPU sync happens in the loop below.
- for sampling_type, metadata in sample_metadata.items():
- seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
- if sampling_type == SamplingType.GREEDY:
- sample_results = _greedy_sample(seq_groups, greedy_samples)
- elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
- sample_results = _random_sample(seq_groups, is_prompts,
- multinomial_samples[sampling_type])
- elif sampling_type == SamplingType.BEAM:
- sample_results = _beam_search_sample(seq_groups, is_prompts,
- sampling_metadata.seq_data,
- beam_search_logprobs)
- sample_results_dict.update(zip(seq_group_ids, sample_results))
- sample_results = [
- sample_results_dict[i]
- for i in range(len(sampling_metadata.seq_groups))
- ]
- return sample_results
- def _get_logprobs(
- logprobs: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- sample_results: List[Tuple[List[int], List[int]]],
- ) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
- int, float]]]]:
- # Prepare query indices
- batched_logprobs_query_seq_indices: List[int] = []
- batched_logprobs_query_token_indices: List[int] = []
- largest_num_logprobs = 0
- sample_idx = 0
- for i, (seq_group, sample_result) in enumerate(
- zip(sampling_metadata.seq_groups, sample_results)):
- seq_ids, sampling_params = seq_group
- next_token_ids, parent_ids = sample_result
- num_parent_seqs = len(seq_ids)
- if (i < sampling_metadata.num_prompts
- and sampling_params.prompt_logprobs is not None):
- largest_num_logprobs = max(largest_num_logprobs,
- sampling_params.prompt_logprobs)
- prompt_len = sampling_metadata.prompt_lens[i]
- prompt_tokens = sampling_metadata.seq_data[
- seq_ids[0]].prompt_token_ids
- batched_logprobs_query_seq_indices.extend(
- sample_idx + j for j in range(prompt_len - 1))
- batched_logprobs_query_token_indices.extend(
- token_id for token_id in prompt_tokens[1:])
- sample_idx += prompt_len - 1
- batched_logprobs_query_seq_indices.extend(
- [sample_idx + parent_id for parent_id in parent_ids])
- batched_logprobs_query_token_indices.extend(next_token_ids)
- if sampling_params.logprobs is not None:
- largest_num_logprobs = max(largest_num_logprobs,
- sampling_params.logprobs)
- sample_idx += num_parent_seqs
- assert sample_idx == logprobs.size(0)
- # Batched query for logprobs of selected token
- batched_logprobs_query_result = logprobs[[
- batched_logprobs_query_seq_indices,
- batched_logprobs_query_token_indices
- ]]
- # Batched query for logprobs of topk tokens
- if largest_num_logprobs > 0:
- top_logprobs, top_token_ids = torch.topk(logprobs,
- largest_num_logprobs,
- dim=-1)
- top_logprobs = top_logprobs.cpu()
- top_token_ids = top_token_ids.cpu()
- else:
- top_logprobs, top_token_ids = None, None
- batched_logprobs_query_result = batched_logprobs_query_result.cpu()
- # Gather results
- result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
- result_sample_logprobs: List[SampleLogprobs] = []
- sample_idx = 0
- query_result_idx = 0
- for i, (seq_group, sample_result) in enumerate(
- zip(sampling_metadata.seq_groups, sample_results)):
- seq_ids, sampling_params = seq_group
- next_token_ids, parent_ids = sample_result
- # Prompt logprobs
- if (i < sampling_metadata.num_prompts
- and sampling_params.prompt_logprobs is not None):
- num_logprobs = sampling_params.prompt_logprobs
- prompt_len = sampling_metadata.prompt_lens[i]
- prompt_tokens = sampling_metadata.seq_data[
- seq_ids[0]].prompt_token_ids
- group_prompt_logprobs: PromptLogprobs = [None]
- for token_id in prompt_tokens[1:]:
- prompt_logprobs_dict = {
- token_id:
- batched_logprobs_query_result[query_result_idx].item()
- }
- if num_logprobs > 0:
- prompt_logprobs_dict.update(
- zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
- top_logprobs[sample_idx, :num_logprobs].tolist()))
- group_prompt_logprobs.append({
- token_id: Logprob(logprob)
- for token_id, logprob in prompt_logprobs_dict.items()
- })
- sample_idx += 1
- query_result_idx += 1
- result_prompt_logprobs.append(group_prompt_logprobs)
- else:
- result_prompt_logprobs.append(None)
- # Sample logprobs
- num_logprobs = sampling_params.logprobs
- if num_logprobs is None:
- num_logprobs = 0
- group_sample_logprobs: SampleLogprobs = []
- for next_token_id, parent_id in zip(next_token_ids, parent_ids):
- sample_logprobs_dict = {
- next_token_id:
- batched_logprobs_query_result[query_result_idx].item()
- }
- query_result_idx += 1
- if num_logprobs > 0:
- sample_logprobs_dict.update(
- zip(
- top_token_ids[sample_idx +
- parent_id, :num_logprobs].tolist(),
- top_logprobs[sample_idx +
- parent_id, :num_logprobs].tolist()))
- group_sample_logprobs.append({
- token_id: Logprob(logprob)
- for token_id, logprob in sample_logprobs_dict.items()
- })
- result_sample_logprobs.append(group_sample_logprobs)
- sample_idx += len(seq_ids)
- return result_prompt_logprobs, result_sample_logprobs
- def _build_sampler_output(
- sample_results: List[Tuple[List[int], List[int]]],
- sampling_metadata: SamplingMetadata,
- prompt_logprobs: List[Optional[PromptLogprobs]],
- sample_logprobs: List[SampleLogprobs],
- output_metadata: OutputMetadata,
- ) -> SamplerOutput:
- sampler_output = []
- for (seq_group, sample_result, group_prompt_logprobs,
- group_sample_logprobs) in zip(sampling_metadata.seq_groups,
- sample_results, prompt_logprobs,
- sample_logprobs):
- seq_ids, _ = seq_group
- next_token_ids, parent_ids = sample_result
- seq_outputs = []
- for parent_id, next_token_id, logprobs in zip(parent_ids,
- next_token_ids,
- group_sample_logprobs):
- seq_outputs.append(
- SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
- output_metadata.get(seq_ids[parent_id])))
- sampler_output.append(
- SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
- return sampler_output
- def _miro_store_args(seqids: List[int], mus: torch.Tensor,
- output_metadata: OutputMetadata) -> None:
- for sid, mu in zip(seqids,
- mus.tolist()): # tolist might be premature optimization
- output_metadata.add(sid, "miro_mu", mu)
- def _apply_mirostat_v2(
- logits: torch.Tensor,
- taus: torch.Tensor, # AKA the targeted surprise
- etas: torch.Tensor, # AKA the learning rate
- mus: torch.
- Tensor, # AKA the accumulator that always tries to approach [tau]
- ) -> torch.Tensor:
- logit_surprise = torch.softmax(
- logits, dim=-1).log2_().neg_() # Calculate surprise value per token
- # For compatibility with ooba/kobold, done in unit of bits(log base 2)
- # not nats(ln).
- # Ideally this would be a log_softmax, for numerical stability and
- # elegance purposes.
- # logit_surprise = torch.log_softmax(logits, dim=-1).neg_()
- miro_mask = logit_surprise > mus.unsqueeze(
- dim=-1) # Mask out "too-surprising" tokens (above mu)
- mininds = torch.argmin(logit_surprise, dim=-1)
- miro_mask.scatter_(
- 1, mininds.unsqueeze(dim=-1), False
- ) # Force at least one outcome to be possible, ideally the most likely one
- logits[miro_mask] = -float("inf")
- probs = torch.softmax(logits, dim=-1,
- dtype=logits.dtype) # Get probs, post-mask
- # NOTE: Mirostat updates its `mu` values based on the sample chosen.
- # The silly approach here is to just sample it and make the logits one-hot.
- # This breaks fine grained seeding, but we don't have that yet.
- # TODO: FIX when it gets added
- next_token_ids = _multinomial(probs, num_samples=1)
- # Calculate new `mu` values.
- # NOTE: If we can know the logit values of the PREVIOUS iteration,
- # it should be possible to update `mu` before applying mirostat each
- # iteration, thus letting us keep _sample as the last thing that happens.
- picked_surprises = torch.gather(logit_surprise,
- dim=-1,
- index=next_token_ids)
- eps = picked_surprises.squeeze() - taus
- mus.sub_(etas * eps)
- logits.fill_(-float("inf"))
- # This value doesn't actually matter, so long as it's not -inf.
- # Vectors are now one-hot, after all.
- logits.scatter_(1, next_token_ids, 1.0)
- return logits
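- # Worked update (illustrative values): with tau=5.0, eta=0.1 and a sampled
- # token whose surprise is 7.0 bits, the error is 7.0 - 5.0 = 2.0 and mu is
- # lowered by 0.1 * 2.0 = 0.2, tightening the mask on the next step.  The
- # logits are then overwritten with a one-hot vector, so the downstream
- # sampling step effectively returns the token chosen here.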
- def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
- output_metadata: OutputMetadata) -> torch.Tensor:
- idx = sampling_tensors.miro_indices
- seqids = sampling_tensors.miro_seqids
- taus = sampling_tensors.miro_taus
- etas = sampling_tensors.miro_etas
- mus = sampling_tensors.miro_mus
- logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
- mus) # mus is an i/o param, :vomit:
- _miro_store_args(seqids, mus, output_metadata)
- return logits