# rejection.py

from functools import cached_property
from typing import Optional, Tuple

import torch
import torch.jit
import torch.nn as nn


class RejectionSampler(nn.Module):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """
    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens would corrupt the KV cache of
                proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._disable_bonus_tokens = disable_bonus_tokens
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store this
        # value in a variable for readability.
        self._num_bonus_tokens = 1

        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0
    def init_gpu_tensors(self, rank: int) -> None:
        assert self.num_accepted_tokens is None
        device = f"cuda:{rank}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)
    @property
    def probs_dtype(self):
        return torch.float32

    @property
    def token_id_dtype(self):
        return torch.int64
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is
        guaranteed that one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as it's cheap to have the target model score this speculative
        sequence.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
                                               draft_probs, draft_token_ids)
            self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                               bonus_token_ids,
                                               draft_token_ids)

        accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
            target_probs,
            draft_probs,
            draft_token_ids,
        )

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )
        return output_token_ids
    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor indicating which tokens in each sequence are
                accepted. shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected. shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # NOTE: the recovered_probs are overwritten by this method.
        recovered_token_ids = _multinomial(recovered_probs,
                                           num_samples=1).reshape(
                                               batch_size, k)
        return accepted, recovered_token_ids
    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                              {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        uniform_rand = torch.rand(batch_size,
                                  k,
                                  dtype=self.probs_dtype,
                                  device=target_probs.device)
        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted
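    # Illustrative note (added, not part of the original code): for a single
    # proposed token with target probability q = 0.3 and draft probability
    # p = 0.6, the capped ratio is min(1, 0.3 / 0.6) = 0.5, so the token is
    # accepted iff the uniform draw falls below 0.5. Whenever q >= p, the
    # ratio caps at 1 and the token is always accepted.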
    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
        according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO: Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs
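    # Worked example (added for illustration, not part of the original code):
    # with a 3-token vocabulary, target q = [0.5, 0.3, 0.2] and draft
    # p = [0.7, 0.2, 0.1], the difference q - p = [-0.2, 0.1, 0.1] is clamped
    # to [~0, 0.1, 0.1] and normalized to the recovered distribution
    # [~0, 0.5, 0.5]; a rejected draft token is then resampled from it.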
    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to
        sample recovered tokens in the first rejection case.

        See _get_recovered_probs for more details.

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny
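    # Note (added for illustration): for float32 this is
    # torch.finfo(torch.float32).tiny ~= 1.18e-38, i.e. 2**-126, the smallest
    # positive normal float32 value.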
    def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        recovered_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size, num_bonus_tokens]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via rejection sampling, all subsequent
        token ids are set to -1 for the sequence.

        shape = [batch_size, k + num_bonus_tokens]
        """
        bonus_token_ids = bonus_token_ids.squeeze()
        batch_size, k = recovered_token_ids.shape

        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # We disable bonus tokens because it causes corrupt KV cache for
        # proposal methods that require KV cache. We can fix it by "prefilling"
        # the bonus token in the proposer.
        if self._disable_bonus_tokens:
            output_with_bonus_tokens[:, -1] = -1

        # Fill the recovered token ids.
        output.mul_(~after_false_mask).add_(
            recovered_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens
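    # Illustrative note (added, not part of the original code): with k = 4 and
    # accepted = [True, True, False, True] for one sequence, the first False
    # is at index 2, so accepted_mask keeps the draft tokens at positions 0-1,
    # after_false_mask places the recovered token at position 2, and position 3
    # and the bonus slot remain -1.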
    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape

        assert draft_batch_size == target_batch_size
        assert num_draft_probs == num_target_probs
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

        assert draft_token_ids_batch_size == draft_batch_size
        assert num_draft_token_ids == num_draft_probs

        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens
    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert all(probs.dtype == self.probs_dtype
                   for probs in [target_probs, draft_probs])
        assert all(token_ids.dtype == self.token_id_dtype
                   for token_ids in [bonus_token_ids, draft_token_ids])
    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
        ]
        assert all([devices[0] == device for device in devices])
    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        bonus_token_ids: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    # Dividing the probabilities by independent Exponential(1) draws and taking
    # the argmax samples from the categorical distribution (the exponential-race
    # variant of the Gumbel-max trick), without forcing a sync.
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
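

# A minimal usage sketch (added for illustration; not part of the original
# module). It assumes a CUDA device is available, uses toy shapes
# (batch_size=2, k=3, vocab_size=8), and builds draft/target distributions
# from random logits, mirroring the shapes documented in forward().
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        batch_size, k, vocab_size = 2, 3, 8

        sampler = RejectionSampler(strict_mode=True)
        sampler.init_gpu_tensors(rank=0)

        draft_probs = torch.softmax(
            torch.randn(batch_size, k, vocab_size, device="cuda"), dim=-1)
        target_probs = torch.softmax(
            torch.randn(batch_size, k, vocab_size, device="cuda"), dim=-1)

        # Draft tokens are sampled from the draft distribution; the bonus
        # token stands in for the target model's sample at position k + 1.
        draft_token_ids = torch.multinomial(
            draft_probs.view(-1, vocab_size),
            num_samples=1).view(batch_size, k)
        bonus_token_ids = torch.randint(0,
                                        vocab_size, (batch_size, 1),
                                        dtype=torch.int64,
                                        device="cuda")

        # shape [batch_size, k + 1]; -1 marks positions after the first
        # rejection (and the bonus slot, since bonus tokens are disabled by
        # default).
        output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs,
                                   draft_token_ids)
        print(output_token_ids)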