rejection_sampler.py

from functools import cached_property
from typing import List, Optional, Tuple

import torch
import torch.jit

from aphrodite.modeling.layers.spec_decode_base_sampler import \
    SpecDecodeStochasticBaseSampler


class RejectionSampler(SpecDecodeStochasticBaseSampler):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens would corrupt the KV cache of
                proposal methods that rely on a KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                         strict_mode=strict_mode)
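
    # For example (illustrative only), RejectionSampler(strict_mode=True)
    # enables the extra shape/device/dtype validation described above at the
    # cost of additional sampling latency.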

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        generators: List[Optional[torch.Generator]],
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is
        guaranteed that one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as well, since it is cheap to have the target model score
        this speculative sequence.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)

        accepted, recovered_token_ids = (
            self._batch_modified_rejection_sampling(
                target_probs,
                draft_probs,
                draft_token_ids,
                generators,
            ))

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )

        return output_token_ids
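
    # Illustration of the output contract above, with hypothetical token ids
    # and k = 3 speculative tokens plus one bonus token per sequence. A row
    # whose first draft token is accepted and whose second is rejected looks
    # like
    #
    #     [draft_id_0, recovered_id_1, -1, -1]
    #
    # while a row whose draft tokens are all accepted keeps the bonus token:
    #
    #     [draft_id_0, draft_id_1, draft_id_2, bonus_id]
    #
    # (no bonus token is emitted when disable_bonus_tokens=True).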

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        generators: List[Optional[torch.Generator]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor indicating which tokens in each sequence are
                accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids, generators)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        seed_indices, non_seed_indices = self._split_batch_by_seeded(
            generators, k=k)

        # NOTE: the recovered_probs are overwritten by this method.
        recovered_token_ids = _multinomial(
            recovered_probs,
            num_samples=1,
            k=k,
            generators=generators,
            seed_indices=seed_indices,
            # this arg is unused when None but torch.jit requires a list
            non_seed_indices=non_seed_indices or [],
        ).reshape(batch_size, k)

        return accepted, recovered_token_ids

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        generators: List[Optional[torch.Generator]],
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                           {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        seed_indices, non_seed_indices = self._split_batch_by_seeded(
            generators)

        if len(seed_indices) == 0:
            uniform_rand = torch.rand_like(selected_target_probs)
        else:
            uniform_rand = torch.empty_like(selected_target_probs)

            for idx in seed_indices:
                uniform_rand[idx, :] = torch.rand(1,
                                                  k,
                                                  dtype=self.probs_dtype,
                                                  device=target_probs.device,
                                                  generator=generators[idx])

            if non_seed_indices:
                uniform_rand[non_seed_indices, :] = torch.rand(
                    len(non_seed_indices),
                    k,
                    dtype=self.probs_dtype,
                    device=target_probs.device)

        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted
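
    # Worked example of the acceptance rule with hypothetical numbers: if the
    # draft model assigned probability 0.6 to the proposed token but the
    # target model assigns it only 0.3, the token is accepted with probability
    # min(1, 0.3 / 0.6) = 0.5. Whenever the target probability is at least the
    # draft probability, the capped ratio is 1 and the token is always
    # accepted.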

    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional
        probability according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO: Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs
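
    # Worked example over a toy 3-token vocabulary with hypothetical numbers:
    # given target q = [0.5, 0.3, 0.2] and draft p = [0.7, 0.2, 0.1], the
    # difference q - p = [-0.2, 0.1, 0.1] is clamped to (nearly) [0, 0.1, 0.1]
    # and renormalized to [0, 0.5, 0.5], so a rejected token is resampled only
    # from tokens the target model favors more than the draft model.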

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to
        sample recovered tokens in the first rejection case.

        See _get_recovered_probs for more details.

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny
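
    # For reference: with float32 probabilities, torch.finfo(torch.float32).tiny
    # is approximately 1.18e-38, the smallest positive *normal* float32 value.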

    # partition batch into indices for which a generator is provided
    # and indices for which no generator is provided
    @staticmethod
    def _split_batch_by_seeded(
        generators: List[Optional[torch.Generator]],
        k: int = 1,
    ) -> Tuple[List[int], Optional[List[int]]]:

        if all(generator is None for generator in generators):
            seed_indices: List[int] = []
            non_seed_indices: Optional[List[int]] = None
        else:
            seed_indices, non_seed_indices = [], []
            for i, generator in enumerate(generators):
                if generator is None:
                    non_seed_indices.extend(range(k * i, k * (i + 1)))
                else:
                    seed_indices.extend(range(k * i, k * (i + 1)))

        return seed_indices, non_seed_indices
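
    # Example with hypothetical inputs: for generators = [None, gen, None] and
    # k = 2, each sequence owns two flattened rows, so this returns
    # seed_indices = [2, 3] (the seeded sequence at batch position 1) and
    # non_seed_indices = [0, 1, 4, 5]. When no generator is provided at all,
    # non_seed_indices is None.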


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    k: int,
    generators: List[Optional[torch.Generator]],
    seed_indices: List[int],
    non_seed_indices: List[int],
) -> torch.Tensor:

    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])

    q = torch.empty_like(probs)
    if len(seed_indices) == 0:
        q.exponential_(1.0)
    else:
        q[non_seed_indices].exponential_(1.0)
        for idx in seed_indices:
            q[idx].exponential_(1.0, generator=generators[idx // k])

    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
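

# A minimal usage sketch of the sync-free multinomial above (illustrative
# only, not part of the sampler): it draws one token id per row from two toy
# CPU distributions with no seeded generators. The probabilities and shapes
# below are made up for demonstration, and running this requires the
# aphrodite package to be importable for the module-level import above.
if __name__ == "__main__":
    toy_probs = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                              [0.7, 0.1, 0.1, 0.1]])
    # _multinomial modifies probs in place, so pass a copy.
    sampled = _multinomial(
        toy_probs.clone(),
        num_samples=1,
        k=1,
        generators=[None, None],
        seed_indices=[],
        non_seed_indices=[],
    )
    # shape = [batch_size, num_samples] = [2, 1]
    print(sampled)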