typical_acceptance_sampler.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import torch
  2. import torch.jit
  3. from aphrodite.modeling.layers.spec_decode_base_sampler import (
  4. SpecDecodeDeterministicBaseSampler)
  5. class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
  6. """Apply typical acceptance sampling as described in section 3.3.2 in
  7. "MEDUSA: Simple LLM Inference Acceleration Framework with
  8. Multiple Decoding Heads"
  9. https://arxiv.org/abs/2401.10774
  10. """
  11. def __init__(
  12. self,
  13. posterior_threshold: float,
  14. posterior_alpha: float,
  15. disable_bonus_tokens: bool = False,
  16. strict_mode: bool = False,
  17. ):
  18. """Create a Typical Acceptance Sampler.
  19. Args:
  20. disable_bonus_tokens: Whether or not to disable the bonus token.
  21. Require when bonus tokens will cause corrupt KV cache for
  22. proposal methods that require KV cache.
  23. strict_mode: Whether or not to perform shape/device/dtype checks
  24. during sampling. This catches correctness issues but adds
  25. nontrivial latency.
  26. posterior_threshold : A threshold value that sets a lower bound
  27. on the posterior probability of a token in target model for it
  28. to be accepted.
  29. posterior_alpha : A scaling factor for the entropy-based
  30. threshold in typical acceptance sampling.
  31. """
  32. self._posterior_threshold = posterior_threshold
  33. self._posterior_alpha = posterior_alpha
  34. super().__init__(disable_bonus_tokens=disable_bonus_tokens,
  35. strict_mode=strict_mode)
  36. def forward(
  37. self,
  38. target_probs: torch.Tensor,
  39. bonus_token_ids: torch.Tensor,
  40. draft_probs: torch.Tensor,
  41. draft_token_ids: torch.Tensor,
  42. ) -> torch.Tensor:
  43. """Sample token ids using typical acceptance sampling. This accepts
  44. or rejects tokens proposed by the draft model using the probability
  45. of each token according to the draft and target models.
  46. In the worst case where all draft tokens are rejected, it is guaranteed
  47. one token will be emitted.
  48. In the case where all draft tokens are accepted, the bonus token will be
  49. accepted conditioned on self._disable_bonus_tokens being false.
  50. Args:
  51. target_probs: The probability distribution over token ids given
  52. context according to the target model.
  53. shape = [batch_size, num_speculative_tokens, vocab_size]
  54. bonus_token_ids: The "bonus" token ids that are accepted iff all
  55. speculative tokens in a sequence are accepted.
  56. shape = [batch_size, num_bonus_tokens]
  57. draft_probs: This parameter is unused by the acceptance sampler.
  58. draft_token_ids: The token ids that were sampled from the draft
  59. probabilities.
  60. shape = [batch_size, num_speculative_tokens]
  61. Returns:
  62. output_token_ids: The token ids sampled via rejection sampling,
  63. or -1 if unable to sample a token because the previous token
  64. was rejected.
  65. shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
  66. """
  67. # Only perform shape/dtype/device checking in strict mode, as it adds
  68. # overhead.
  69. if self._strict_mode:
  70. self._raise_if_incorrect_input(target_probs, draft_token_ids,
  71. bonus_token_ids)
  72. accepted = self._evaluate_accepted_tokens(target_probs,
  73. draft_token_ids)
  74. recovered_token_ids = self._replacement_token_ids(target_probs)
  75. output_token_ids = self._create_output(accepted, recovered_token_ids,
  76. draft_token_ids,
  77. bonus_token_ids)
  78. return output_token_ids
  79. def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
  80. r"""
  81. Evaluates and returns a mask of accepted tokens based on the
  82. posterior probabilities.
  83. Parameters:
  84. ----------
  85. target_probs : torch.Tensor
  86. A tensor of shape (batch_size, k, vocab_size) representing
  87. the probabilities of each token in the vocabulary for each
  88. position in the proposed sequence. This is the distribution
  89. generated by the target model.
  90. draft_token_ids : torch.Tensor
  91. A tensor of shape (batch_size, k) representing the proposed
  92. token ids.
  93. A draft token_id x_{n+k} is accepted if it satisfies the
  94. following condition
  95. .. math::
  96. p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
  97. \min \left( \epsilon, \delta * \exp \left(
  98. -H(p_{\text{original}}(
  99. \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
  100. where :math:`p_{\text{original}}` corresponds to target_probs
  101. and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
  102. specified using self._posterior_threshold and self._posterior_alpha
  103. This method computes the posterior probabilities for the given
  104. draft token ids based on the provided target probabilities. It
  105. calculates the entropy of the posterior distribution and determines
  106. a dynamic threshold for each token position using the provided
  107. posterior_threshold and posterior_alpha values. The method then
  108. returns a boolean mask indicating which tokens can be accepted.
  109. Returns:
  110. -------
  111. torch.Tensor
  112. A boolean tensor of shape (batch_size, k) where each element
  113. indicates whether the corresponding draft token has been accepted
  114. or rejected. True indicates acceptance and false indicates
  115. rejection.
  116. """
  117. device = target_probs.device
  118. candidates_prob = torch.gather(
  119. target_probs, dim=-1,
  120. index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
  121. # A small constant added to prevent computing the logarithm of zero,
  122. # which can lead to undefined values.
  123. epsilon = 1e-5
  124. posterior_entropy = -torch.sum(
  125. target_probs * torch.log(target_probs + epsilon), dim=-1)
  126. threshold = torch.minimum(
  127. torch.ones_like(posterior_entropy, device=device) *
  128. self._posterior_threshold,
  129. torch.exp(-posterior_entropy) * self._posterior_alpha,
  130. )
  131. accepted_mask = candidates_prob > threshold
  132. return accepted_mask
  133. def _replacement_token_ids(self, target_probs):
  134. """
  135. Generate one replacement token ID for each sequence based on target
  136. probabilities. The replacement token is used as the fallback option
  137. if typical acceptance sampling does not accept any draft tokens for
  138. that particular sequence.
  139. This method computes the token IDs to be replaced by selecting the
  140. token with the highest probability for each sequence in the first
  141. position. The rest of the output is filled with -1.
  142. Parameters
  143. ----------
  144. target_probs : torch.Tensor
  145. A tensor of shape (batch_size, k, vocab_size) containing
  146. the target probability distribution
  147. Returns
  148. -------
  149. torch.Tensor
  150. A tensor of shape (batch_size, k) with the replacement
  151. token IDs. Only the first column is set, and the rest of the
  152. columns are filled with -1.
  153. """
  154. max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
  155. output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
  156. dtype=self.token_id_dtype,
  157. device=target_probs.device)
  158. output[:, 0] = max_indices
  159. return output