import torch
import torch.jit

from aphrodite.modeling.layers.spec_decode_base_sampler import (
    SpecDecodeDeterministicBaseSampler)


class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
    """Apply typical acceptance sampling as described in section 3.3.2 of
    "MEDUSA: Simple LLM Inference Acceleration Framework with
    Multiple Decoding Heads"
    https://arxiv.org/abs/2401.10774
    """

    def __init__(
        self,
        posterior_threshold: float,
        posterior_alpha: float,
        disable_bonus_tokens: bool = False,
        strict_mode: bool = False,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            posterior_threshold: A threshold value that sets a lower bound
                on the posterior probability a token must have under the
                target model in order to be accepted.
            posterior_alpha: A scaling factor for the entropy-based
                threshold in typical acceptance sampling.
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens would corrupt the KV cache of
                proposal methods that maintain one.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha
        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                         strict_mode=strict_mode)

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling. This accepts
        or rejects tokens proposed by the draft model based on the
        probability each token has under the target model.

        In the worst case, where all draft tokens are rejected, it is
        guaranteed that one token will be emitted.

        In the case where all draft tokens are accepted, the bonus token is
        also accepted, provided self._disable_bonus_tokens is False.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: This parameter is unused by this sampler.
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via typical acceptance
                sampling, or -1 if a token could not be sampled because the
                previous token was rejected.
                shape = [batch_size,
                         num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, draft_token_ids,
                                           bonus_token_ids)
        accepted = self._evaluate_accepted_tokens(target_probs,
                                                  draft_token_ids)
        recovered_token_ids = self._replacement_token_ids(target_probs)
        output_token_ids = self._create_output(accepted, recovered_token_ids,
                                               draft_token_ids,
                                               bonus_token_ids)
        return output_token_ids
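
    # Example of the output layout for num_speculative_tokens = 2 and one
    # bonus token, following the docstring above:
    #   * all draft tokens accepted -> [draft_0, draft_1, bonus_0]
    #     (bonus_0 becomes -1 when self._disable_bonus_tokens is True)
    #   * no draft tokens accepted  -> [recovered_0, -1, -1], where
    #     recovered_0 is the target model's argmax at the first position
    #     (see _replacement_token_ids below).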

    def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
        r"""Evaluate and return a mask of accepted tokens based on the
        posterior probabilities.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) representing
            the probabilities of each token in the vocabulary for each
            position in the proposed sequence. This is the distribution
            generated by the target model.
        draft_token_ids : torch.Tensor
            A tensor of shape (batch_size, k) representing the proposed
            token ids.

        A draft token id :math:`x_{n+k}` is accepted if it satisfies the
        following condition:

        .. math::
            p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
            \min \left( \epsilon, \delta \exp \left(
            -H(p_{\text{original}}(
            \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

        where :math:`p_{\text{original}}` corresponds to target_probs,
        :math:`H` is the entropy function, and :math:`\epsilon` and
        :math:`\delta` are the hyperparameters self._posterior_threshold
        and self._posterior_alpha, respectively.

        This method computes the posterior probability of each given draft
        token id from the provided target probabilities, calculates the
        entropy of the posterior distribution, and determines a dynamic
        threshold for each token position from the posterior_threshold and
        posterior_alpha values. It then returns a boolean mask indicating
        which tokens are accepted.

        Returns
        -------
        torch.Tensor
            A boolean tensor of shape (batch_size, k) where each element
            indicates whether the corresponding draft token has been
            accepted (True) or rejected (False).
        """
        device = target_probs.device
        # Posterior probability of each draft token id under the target
        # model.
        candidates_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # A small constant added to prevent computing the logarithm of zero,
        # which can lead to undefined values.
        epsilon = 1e-5
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + epsilon), dim=-1)
        threshold = torch.minimum(
            torch.ones_like(posterior_entropy, device=device) *
            self._posterior_threshold,
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
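        # Example: with posterior_threshold=0.3 and posterior_alpha=0.5, a
        # uniform posterior over 4 tokens has entropy ln(4) ~= 1.386, so the
        # entropy term is 0.5 * exp(-1.386) = 0.125 and the threshold is
        # min(0.3, 0.125) = 0.125; a draft token with posterior probability
        # 0.25 would then be accepted.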
        accepted_mask = candidates_prob > threshold
        return accepted_mask

    def _replacement_token_ids(self, target_probs):
        """Generate one replacement token id for each sequence based on the
        target probabilities. The replacement token is used as the fallback
        if typical acceptance sampling does not accept any draft tokens for
        that particular sequence.

        This method selects the highest-probability token at the first
        position of each sequence as the replacement. The rest of the
        output is filled with -1.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) containing
            the target probability distribution.

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, k) with the replacement
            token ids. Only the first column is set; the remaining
            columns are filled with -1.
        """
        max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
        output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
                             dtype=self.token_id_dtype,
                             device=target_probs.device)
        output[:, 0] = max_indices
        return output
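

# A minimal, self-contained sketch of the acceptance rule on toy tensors.
# It is illustrative only, not part of the sampler: it exercises
# _evaluate_accepted_tokens, which needs no state beyond what __init__
# sets, and the hyperparameter values below are arbitrary toy choices.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, k, vocab_size = 2, 3, 8
    # Toy target-model posteriors (each row sums to 1) and random draft
    # token ids.
    target_probs = torch.softmax(
        torch.randn(batch_size, k, vocab_size), dim=-1)
    draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))

    sampler = TypicalAcceptanceSampler(posterior_threshold=0.09,
                                       posterior_alpha=0.3)
    accepted = sampler._evaluate_accepted_tokens(target_probs,
                                                 draft_token_ids)
    print(accepted)  # Boolean acceptance mask, shape [batch_size, k].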