typical_acceptance_sampler.py

import torch
import torch.jit

from aphrodite.modeling.layers.spec_decode_base_sampler import (
    SpecDecodeDeterministicBaseSampler)


class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
    """Apply typical acceptance sampling as described in section 3.3.1 of
    "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple
    Decoding Heads" (https://arxiv.org/pdf/2401.10774).
    """
    def __init__(
        self,
        posterior_threshold: float,
        posterior_alpha: float,
        strict_mode: bool = False,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            posterior_threshold: A threshold value that sets a lower bound
                on the posterior probability of a token under the target
                model for it to be accepted.
            posterior_alpha: A scaling factor for the entropy-based
                threshold in typical acceptance sampling.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha
        super().__init__(strict_mode=strict_mode)
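    # Illustrative construction (the hyperparameter values here are
    # hypothetical, not library defaults):
    #     sampler = TypicalAcceptanceSampler(posterior_threshold=0.09,
    #                                        posterior_alpha=0.3)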
    def forward(
        self,
        target_with_bonus_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling. This accepts
        or rejects tokens proposed by the draft model based on the
        probability each token is assigned by the target model.

        In the worst case, where all draft tokens are rejected, it is
        guaranteed that one token will be emitted.

        In the case where all draft tokens are accepted, the bonus token is
        accepted as well.

        Args:
            target_with_bonus_probs: The probability distribution over token
                ids given context according to the target model, including
                the distribution for the bonus position.
                shape = [batch_size, num_speculative_tokens + 1, vocab_size]
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: This parameter is unused by this acceptance sampler.
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via typical acceptance
                sampling, or -1 if unable to sample a token because the
                previous token was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids)
        # Drop the distribution for the bonus position; acceptance applies
        # to the draft tokens only.
        target_probs = target_with_bonus_probs[:, :-1]
        accepted = self._evaluate_accepted_tokens(target_probs,
                                                  draft_token_ids)
        recovered_token_ids = self._get_recovered_token_ids(target_probs)
        output_token_ids = self._create_output(accepted, recovered_token_ids,
                                               draft_token_ids,
                                               bonus_token_ids)
        return output_token_ids
    def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
        r"""Evaluate and return a mask of accepted tokens based on the
        posterior probabilities.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) representing
            the probabilities of each token in the vocabulary for each
            position in the proposed sequence. This is the distribution
            generated by the target model.
        draft_token_ids : torch.Tensor
            A tensor of shape (batch_size, k) representing the proposed
            token ids.

        A draft token id :math:`x_{n+k}` is accepted if it satisfies the
        following condition:

        .. math::
            p_{\text{original}}(x_{n+k} \mid x_1, x_2, \dots, x_{n+k-1}) >
            \min\left(\epsilon, \delta \exp\left(
                -H(p_{\text{original}}(
                    \cdot \mid x_1, x_2, \ldots, x_{n+k-1}))\right)\right)

        where :math:`p_{\text{original}}` corresponds to target_probs, and
        :math:`\epsilon` and :math:`\delta` correspond to the hyperparameters
        specified via self._posterior_threshold and self._posterior_alpha,
        respectively.

        This method computes the posterior probabilities for the given
        draft token ids based on the provided target probabilities. It
        calculates the entropy of the posterior distribution and determines
        a dynamic threshold for each token position using the provided
        posterior_threshold and posterior_alpha values. The method then
        returns a boolean mask indicating which tokens are accepted.

        Returns
        -------
        torch.Tensor
            A boolean tensor of shape (batch_size, k) where each element
            indicates whether the corresponding draft token has been
            accepted. True indicates acceptance and False indicates
            rejection.
        """
        device = target_probs.device
        candidates_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # A small constant added to prevent computing the logarithm of zero,
        # which can lead to undefined values. (Distinct from the epsilon
        # hyperparameter in the docstring, which is self._posterior_threshold.)
        epsilon = 1e-5
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + epsilon), dim=-1)
        threshold = torch.minimum(
            torch.ones_like(posterior_entropy, device=device) *
            self._posterior_threshold,
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
        accepted_mask = candidates_prob > threshold
        return accepted_mask
    def _get_recovered_token_ids(self, target_probs):
        """Get the token ids used to replace the first rejected draft token:
        the target model's highest-probability token at each position.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) containing
            the target probability distribution.

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, k) with the recovered token
            ids, selected as the argmax of target_probs.
        """
        max_indices = torch.argmax(target_probs, dim=-1)
        return max_indices
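# A minimal, self-contained sketch of the acceptance rule implemented in
# _evaluate_accepted_tokens, run on toy tensors. The shapes, seed, and
# hyperparameter values below are illustrative assumptions, not values from
# the library; the computation itself needs only torch (though running this
# file still imports the aphrodite base class at the top).
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, k, vocab_size = 2, 3, 8
    posterior_threshold, posterior_alpha = 0.09, 0.3  # hypothetical values

    # Fake target-model distributions and draft proposals.
    target_probs = torch.softmax(
        torch.randn(batch_size, k, vocab_size), dim=-1)
    draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))

    # Posterior probability of each proposed token under the target model.
    candidates_prob = torch.gather(
        target_probs, dim=-1,
        index=draft_token_ids.unsqueeze(-1)).squeeze(-1)

    # Entropy-scaled dynamic threshold: min(epsilon, delta * exp(-H)).
    posterior_entropy = -torch.sum(
        target_probs * torch.log(target_probs + 1e-5), dim=-1)
    threshold = torch.minimum(
        torch.full_like(posterior_entropy, posterior_threshold),
        torch.exp(-posterior_entropy) * posterior_alpha,
    )

    print("accepted mask:\n", candidates_prob > threshold)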