rejection_sampler.py

from functools import cached_property
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit

from aphrodite.modeling.layers.spec_decode_base_sampler import (
    SpecDecodeStochasticBaseSampler)


class RejectionSampler(SpecDecodeStochasticBaseSampler):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens will cause corrupt KV cache for
                proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                         strict_mode=strict_mode)

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is
        guaranteed that one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as it's cheap to have the target model score this speculative
        sequence.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]
            seeded_seqs: Dict of batch row index to torch generator, for
                sequences using seeded generation.

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, draft_token_ids,
                                           bonus_token_ids, draft_probs)

        accepted, recovered_token_ids = (
            self._batch_modified_rejection_sampling(
                target_probs,
                draft_probs,
                draft_token_ids,
                seeded_seqs,
            ))

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )

        return output_token_ids
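
    # Illustrative output layout (following the docstring above): with k = 3
    # speculative tokens and one bonus token, if the first two draft tokens
    # are accepted and the third is rejected, the output row becomes
    #   [draft_0, draft_1, recovered_2, -1]
    # whereas if all three are accepted it becomes
    #   [draft_0, draft_1, draft_2, bonus]
    # (assuming bonus tokens are not disabled).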

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[Dict[int, torch.Generator]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor indicating which tokens in each sequence are
                accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids, seeded_seqs)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # NOTE: the recovered_probs are overwritten by this method.
        recovered_token_ids = _multinomial(
            recovered_probs,
            num_samples=1,
            k=k,
            seeded_seqs=seeded_seqs or {},
        ).reshape(batch_size, k)

        return accepted, recovered_token_ids

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        seeded_seqs: Optional[Dict[int, torch.Generator]],
    ) -> torch.Tensor:
        r"""Create a bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                              {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        if not seeded_seqs:
            uniform_rand = torch.rand_like(selected_target_probs)
        else:
            uniform_rand = torch.empty_like(selected_target_probs)

            non_seeded_indices = []
            for idx in range(batch_size):
                generator = seeded_seqs.get(idx)
                if generator is None:
                    non_seeded_indices.append(idx)
                else:
                    uniform_rand[idx, :] = torch.rand(
                        1,
                        k,
                        dtype=self.probs_dtype,
                        device=target_probs.device,
                        generator=generator)
            if non_seeded_indices:
                uniform_rand[non_seeded_indices, :] = torch.rand(
                    len(non_seeded_indices),
                    k,
                    dtype=self.probs_dtype,
                    device=target_probs.device)

        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted
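
    # Illustrative note on the acceptance rule above: if the draft assigned
    # the proposed token probability p = 0.6 while the target assigns it
    # q = 0.3, the capped ratio is min(1, 0.3 / 0.6) = 0.5, so the token is
    # kept only when the uniform draw for that position falls below 0.5.
    # Whenever q >= p, the ratio caps at 1 and the token is always accepted.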

    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model, and :math:`p(x|x_1, \dots, x_n)`, the same conditional
        probability according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on the GPU and thus constructs the
        recovered distribution for all tokens, even if they are accepted.
        This can cause division-by-zero errors, so we use
        self._smallest_positive_value to avoid that. This introduces some
        drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO: Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs
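
    # Illustrative note on the recovered distribution: with a three-token
    # vocabulary, target probs q = [0.5, 0.3, 0.2] and draft probs
    # p = [0.2, 0.5, 0.3] give q - p = [0.3, -0.2, -0.1]; clamping the
    # negatives to the smallest positive value and normalizing yields
    # approximately [1.0, 0.0, 0.0], i.e. rejected tokens are resampled from
    # the mass the draft model under-covered.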

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to
        sample recovered tokens in the first rejection case.

        See _get_recovered_probs for more details.

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    k: int,
    seeded_seqs: Dict[int, torch.Generator],
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])

    # Exponential-race sampling: dividing each probability by an independent
    # Exp(1) draw and taking the argmax samples from the categorical
    # distribution without a GPU<->CPU sync.
    q = torch.empty_like(probs)
    if not seeded_seqs:
        q.exponential_(1.0)
    else:
        non_seeded_indices: List[int] = []
        start = 0
        for idx in range(len(q) // k):
            end = start + k
            generator = seeded_seqs.get(idx)
            if generator is None:
                non_seeded_indices.extend(list(range(start, end)))
            else:
                q[start:end].exponential_(1.0, generator=generator)
            start = end
        if non_seeded_indices:
            # Advanced indexing returns a copy, so the exponential samples
            # must be written back into q explicitly.
            non_seeded = torch.tensor(non_seeded_indices,
                                      dtype=torch.long,
                                      device=q.device)
            q[non_seeded] = q[non_seeded].exponential_(1.0)
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
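

# A minimal smoke-test sketch: it exercises the standalone _multinomial
# helper with random inputs to show the expected shapes. The commented-out
# portion outlines how RejectionSampler.forward would be called (shapes per
# its docstring); actually running that part would also need whatever device
# state the aphrodite base sampler expects, which is assumed rather than
# shown here.
if __name__ == "__main__":
    batch_size, k, vocab_size = 2, 3, 8

    # Per-position distributions, flattened to [batch_size * k, vocab_size]
    # exactly as _batch_modified_rejection_sampling does before sampling.
    flat_probs = torch.softmax(torch.randn(batch_size * k, vocab_size),
                               dim=-1)

    # No seeded sequences, so pass an empty generator dict (the same value
    # forward() passes when seeded_seqs is None).
    sampled = _multinomial(flat_probs, num_samples=1, k=k, seeded_seqs={})
    print(sampled.reshape(batch_size, k))  # one token id per position

    # Sketch of the full sampler call:
    #   sampler = RejectionSampler()
    #   output_token_ids = sampler(
    #       target_probs,     # [batch_size, k, vocab_size]
    #       bonus_token_ids,  # [batch_size, 1]
    #       draft_probs,      # [batch_size, k, vocab_size]
    #       draft_token_ids,  # [batch_size, k]
    #   )
    #   # output_token_ids: [batch_size, k + 1], padded with -1 after the
    #   # first rejected position.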