from typing import Tuple, Optional
from functools import cached_property

import torch
import torch.nn as nn
import torch.jit


class RejectionSampler(nn.Module):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self, strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self.probs_dtype = torch.float32
        self.token_id_dtype = torch.int64
        self._num_bonus_tokens = 1
        self._strict_mode = strict_mode

        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: int) -> None:
        assert self.num_accepted_tokens is None
        device = f"cuda:{rank}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
  42. """Sample token ids using rejection sampling. This accepts or rejects
  43. tokens proposed by the draft model using the probability of each token
  44. according to the draft and target models.
  45. In the worst case where all draft tokens are rejected, it is guaranteed
  46. one correct token will be emitted.
  47. In the case where all draft tokens are accepted, a bonus token will be
  48. accepted as its cheap to have the target model score this speculative
  49. sequence.
  50. Args:
  51. target_probs: The probability distribution over token ids given
  52. context according to the target model.
  53. shape = [batch_size, num_speculative_tokens, vocab_size]
  54. bonus_token_ids: The "bonus" token ids that are accepted iff all
  55. speculative tokens in a sequence are accepted.
  56. shape = [batch_size, num_bonus_tokens]
  57. draft_probs: The probability distribution over token ids given
  58. context according to the draft model.
  59. shape = [batch_size, num_speculative_tokens, vocab_size]
  60. draft_token_ids: The token ids that were sampled from the draft
  61. probabilities.
  62. shape = [batch_size, num_speculative_tokens]
  63. Returns:
  64. output_token_ids: The token ids sampled via rejection sampling,
  65. or -1 if unable to sample a token because the previous token
  66. was rejected.
  67. shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
  68. """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
                                               draft_probs, draft_token_ids)
            self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                               bonus_token_ids,
                                               draft_token_ids)

        accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
            target_probs,
            draft_probs,
            draft_token_ids,
        )

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )
        return output_token_ids

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor indicating which tokens in each sequence are
                accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        recovered_token_ids = _multinomial(recovered_probs,
                                           num_samples=1).reshape(
                                               batch_size, k)
        return accepted, recovered_token_ids

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> torch.Tensor:
        r"""Create a bool matrix over the proposed draft tokens. If True,
        then a token can be accepted, else it should be rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                           {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        uniform_rand = torch.rand(batch_size,
                                  k,
                                  dtype=self.probs_dtype,
                                  device=target_probs.device)
        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
        according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to
        sample recovered tokens in the first rejection case.

        See _get_recovered_probs for more details.

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny

    def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        recovered_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size, num_bonus_tokens]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via rejection sampling, all subsequent
        token ids are set to -1 for the sequence.

        shape = [batch_size, k + num_bonus_tokens]
        """
        bonus_token_ids = bonus_token_ids.squeeze()
        batch_size, k = recovered_token_ids.shape

        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # Fill the recovered token ids.
        output.mul_(~after_false_mask).add_(
            recovered_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape

        assert draft_batch_size == target_batch_size
        assert num_draft_probs == num_target_probs
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

        assert draft_token_ids_batch_size == draft_batch_size
        assert num_draft_token_ids == num_draft_probs

        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert all(probs.dtype == self.probs_dtype
                   for probs in [target_probs, draft_probs])
        assert all(token_ids.dtype == self.token_id_dtype
                   for token_ids in [bonus_token_ids, draft_token_ids])

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
        ]
        # pylint: disable=use-a-generator
        assert all([devices[0] == device for device in devices])

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        bonus_token_ids: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
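# Explanatory note (not from the original source): the argmax of
# probs / Exponential(1) noise below is an exponential-race variant of the
# Gumbel-max trick; the index that maximizes probs[i] / q[i] (equivalently
# minimizes q[i] / probs[i], an Exponential with rate probs[i]) is
# distributed according to probs, so no categorical sampler and no CPU sync
# are needed.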
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
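

# Illustrative usage sketch (not part of the original module): it builds toy
# draft/target distributions for a batch of two sequences with k = 3
# speculative tokens and runs one rejection-sampling step. All tensor values
# are arbitrary and exist only to exercise the shapes documented in
# forward(); the CPU fallback below hand-initializes the metric counters
# because init_gpu_tensors assumes a CUDA device.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, k, vocab_size = 2, 3, 8

    sampler = RejectionSampler(strict_mode=True)
    if torch.cuda.is_available():
        sampler.init_gpu_tensors(rank=0)
        device = "cuda:0"
    else:
        device = "cpu"
        sampler.num_accepted_tokens = torch.tensor(0, dtype=torch.long)
        sampler.num_emitted_tokens = torch.tensor(0, dtype=torch.long)

    draft_probs = torch.softmax(
        torch.randn(batch_size, k, vocab_size, device=device), dim=-1)
    target_probs = torch.softmax(
        torch.randn(batch_size, k, vocab_size, device=device), dim=-1)
    draft_token_ids = torch.multinomial(
        draft_probs.reshape(-1, vocab_size),
        num_samples=1).reshape(batch_size, k)
    bonus_token_ids = torch.randint(vocab_size, (batch_size, 1),
                                    device=device)

    # Each row holds accepted draft tokens, then one recovered token if a
    # draft token was rejected, with -1 padding afterwards; the bonus token
    # appears only when every draft token in the row was accepted.
    output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs,
                               draft_token_ids)
    print(output_token_ids)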