- """A layer that samples the next tokens from the model's outputs."""
- from typing import Dict, List, Tuple, Optional
- import numpy as np
- import torch
- import torch.nn as nn
- from aphrodite.modeling.metadata import InputMetadata
- from aphrodite.modeling.megatron.tensor_parallel import (
- gather_from_tensor_model_parallel_region)
- from aphrodite.common.sampling_params import SamplingParams
- from aphrodite.common.sequence import SamplerOutput, SequenceOutputs
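
# Tolerance for treating floating-point sampling parameters as exact values:
# a temperature below this threshold is handled as greedy sampling, a top_p
# within this distance of 1.0 disables top-p truncation, and penalties with
# an absolute value below it are skipped entirely.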
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = gather_from_tensor_model_parallel_region(logits)
        # Remove paddings in vocab (if any).
        logits = logits[:, :self.vocab_size]

        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties = _get_penalties(
            input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties, self.vocab_size)

        # Apply user-provided logits processors (if any).
        logits = _apply_logits_processors(input_metadata, logits,
                                          output_tokens)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities from the same (already truncated)
        # logits. Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)
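
# Illustrative usage (a hedged sketch, not code from this repo): a model's
# forward pass would typically end by handing its final hidden states and the
# output-embedding (LM head) weight to this sampler, e.g.
#
#   sampler = Sampler(vocab_size)                       # hypothetical setup
#   hidden_states = backbone(input_ids, positions,      # hypothetical call
#                            kv_caches, input_metadata)
#   next_tokens = sampler(lm_head_weight, hidden_states, input_metadata)
#
# where `lm_head_weight` stands in for the (vocab_size, hidden_size) weight
# used to project hidden states to vocabulary logits.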


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
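    # hidden_states is laid out as all prompt tokens (one contiguous run per
    # prompt) followed by one hidden state per generation token; only the
    # last token of each prompt and every generation token are kept.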
    start_idx = 0
    last_token_indices: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indices.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indices.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states.index_select(
        0, torch.tensor(last_token_indices, device=hidden_states.device))


def _get_penalties(
        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
    # Collect the presence and frequency penalties.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        if i < input_metadata.num_prompts:
            # A prompt input.
            presence_penalties.append(p)
            frequency_penalties.append(f)
        else:
            # A generation token.
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
    return presence_penalties, frequency_penalties


def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    output_tokens: List[List[int]] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, _ = seq_group
        if i < input_metadata.num_prompts:
            # A prompt input.
            # NOTE: While the prompt input usually has no output tokens,
            # it may have output tokens in the case of recomputation.
            seq_id = seq_ids[0]
            seq_data = input_metadata.seq_data[seq_id]
            output_tokens.append(seq_data.output_token_ids)
        else:
            # A generation token.
            for seq_id in seq_ids:
                seq_data = input_metadata.seq_data[seq_id]
                output_tokens.append(seq_data.output_token_ids)
    return output_tokens


def _apply_logits_processors(
    input_metadata: InputMetadata,
    logits: torch.Tensor,
    output_tokens: List[List[int]],
) -> torch.Tensor:
    for _, sampling_params in input_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors is not None:
            for logits_processor in logits_processors:
                logits = logits_processor(logits, output_tokens)
    return logits


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    vocab_size: int,
) -> torch.Tensor:
    num_seqs = logits.shape[0]
    # Collect the indices of sequences that have non-zero penalties.
    indices = []
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue
        p = presence_penalties[i]
        f = frequency_penalties[i]
        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
            continue
        indices.append(i)

    # Return early if all sequences have zero penalties.
    if not indices:
        return logits

    bin_counts = []
    for i in indices:
        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
    bin_counts = np.stack(bin_counts, axis=0)
    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
                                                 device=logits.device)

    frequency_penalties = [frequency_penalties[i] for i in indices]
    frequency_penalties = torch.tensor(frequency_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
    presence_penalties = [presence_penalties[i] for i in indices]
    presence_penalties = torch.tensor(presence_penalties,
                                      dtype=logits.dtype,
                                      device=logits.device)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
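    # Concretely, for each token t that already appears c_t times in a
    # sequence's output, the update below computes:
    #     logits[t] -= frequency_penalty * c_t + presence_penalty * (c_t > 0)
    # i.e. the frequency term grows with the repetition count, while the
    # presence term is a flat one-time penalty for any token seen at all.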
    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
    return logits


def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if i < input_metadata.num_prompts:
            # A prompt input.
            temperatures.append(temperature)
        else:
            # A generation token.
            temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    top_ps: List[float] = []
    top_ks: List[int] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
        if i < input_metadata.num_prompts:
            # A prompt input.
            top_ps.append(top_p)
            top_ks.append(top_k)
        else:
            # A generation token.
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks


def _apply_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
) -> torch.Tensor:
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
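    # (probs_sum - probs_sort) is the cumulative probability mass *before*
    # each sorted token, so a token is masked only when the nucleus is
    # already full without it; the most likely token always survives.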
    logits_sort[top_p_mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    logits_sort[top_k_mask] = -float("inf")

    # Re-sort the probabilities.
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: Optional[int],
) -> Dict[int, float]:
    if num_logprobs is None or num_logprobs == 0:
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
    if sampling_params.use_beam_search:
        # Beam search.
        beam_width = sampling_params.best_of
        # Sample 2 * beam_width candidates to make sure that with high
        # probability we can get `beam_width` candidates in addition to
        # the finished sequences for the next iteration. See
        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
        # for details. See also HF reference:
        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
        _, next_token_ids = torch.topk(prob, 2 * beam_width)
        next_token_ids = next_token_ids.tolist()
    elif sampling_params.temperature < _SAMPLING_EPS:
        # Greedy sampling.
        assert sampling_params.best_of == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
    else:
        # Random sampling.
        # Sample `best_of` tokens for the prompt.
        num_seqs = sampling_params.best_of
        next_token_ids = torch.multinomial(prob,
                                           num_samples=num_seqs,
                                           replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    # NOTE: sampling_params.best_of can be greater than
    # len(seq_ids) because some sequences in the group might have
    # been already terminated.
    if sampling_params.use_beam_search:
        # Beam search.
        # Add cumulative logprobs for the sequences in the group.
        seq_logprobs = torch.tensor(seq_logprobs,
                                    dtype=torch.float,
                                    device=logprobs.device)
        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

        vocab_size = logprobs.size(-1)
        beam_width = len(seq_ids)
        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
        topk_ids = topk_ids.tolist()
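        # Each index into the flattened (beam_width, vocab_size) tensor
        # encodes a (beam, token) pair: i // vocab_size recovers the parent
        # sequence and i % vocab_size recovers the next token id.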
        seq_idx = [i // vocab_size for i in topk_ids]
        parent_seq_ids = [seq_ids[i] for i in seq_idx]
        next_token_ids = [i % vocab_size for i in topk_ids]
    elif sampling_params.temperature < _SAMPLING_EPS:
        # Greedy sampling.
        assert len(seq_ids) == 1
        next_token_id = torch.argmax(probs, dim=-1)
        next_token_ids = [int(next_token_id.item())]
        parent_seq_ids = seq_ids
    else:
        # Random sampling.
        # Sample 1 token for each sequence in the group.
        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True)
        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
        parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> SamplerOutput:
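    # Returns one list of SequenceOutputs per sequence group, in the same
    # order as input_metadata.seq_groups (prompt groups first, followed by
    # groups in the generation phase).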
    seq_outputs: SamplerOutput = []

    # TODO: Optimize.
    idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_group_outputs: List[SequenceOutputs] = []
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == 1, "Prompt input should have only one seq."
            parent_seq_id = seq_ids[0]
            prob = probs[idx]
            logprob = logprobs[idx]
            idx += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)
            # Get top-k log probabilities for the next tokens.
            next_logprobs = _get_topk_logprobs(logprob,
                                               sampling_params.logprobs)

            # Build the output.
            for next_token_id in next_token_ids:
                output_logprobs = next_logprobs.copy()
                output_logprobs[next_token_id] = logprob[next_token_id].item()
                seq_group_outputs.append(
                    SequenceOutputs(parent_seq_id, next_token_id,
                                    output_logprobs))
        else:
            # Generate the next tokens for generation tokens.
            num_parent_seqs = len(seq_ids)
            prob = probs[idx:idx + num_parent_seqs]
            logprob = logprobs[idx:idx + num_parent_seqs]
            idx += num_parent_seqs

            # Sample the next tokens.
            seq_logprobs = [
                input_metadata.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                seq_ids, prob, logprob, seq_logprobs, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs: Dict[int, Dict[int, float]] = {}
            for j, seq_id in enumerate(seq_ids):
                next_logprobs[seq_id] = _get_topk_logprobs(
                    logprob[j], sampling_params.logprobs)

            # Build the output.
            for parent_seq_id, next_token_id in zip(parent_seq_ids,
                                                    next_token_ids):
                j = seq_ids.index(parent_seq_id)
                output_logprobs = next_logprobs[parent_seq_id].copy()
                output_logprobs[next_token_id] = logprob[j,
                                                         next_token_id].item()
                seq_group_outputs.append(
                    SequenceOutputs(parent_seq_id, next_token_id,
                                    output_logprobs))
        seq_outputs.append(seq_group_outputs)
    return seq_outputs