123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401 |
- from functools import cached_property
- from typing import Optional, Tuple
- import torch
- import torch.jit
- import torch.nn as nn
class RejectionSampler(nn.Module):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self, strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store
        # this value in a variable for readability.
        self._num_bonus_tokens = 1

        # Running totals updated by every call to _create_output. The two
        # tensor counters are accumulated with in-place `+=` of tensor sums
        # (no `.item()`). They are allocated by init_gpu_tensors or, if that
        # was never called, lazily on the device of the first batch.
        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: int) -> None:
        """Allocate the zero-initialised counters on device ``cuda:{rank}``.

        Asserts that the counters have not been allocated yet.
        """
        assert self.num_accepted_tokens is None
        self._allocate_counters(torch.device(f"cuda:{rank}"))

    def _allocate_counters(self, device: torch.device) -> None:
        """Create the zero-initialised acceptance counters on ``device``."""
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    @property
    def probs_dtype(self):
        return torch.float32

    @property
    def token_id_dtype(self):
        return torch.int64

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as it's cheap to have the target model score this speculative
        sequence.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]

            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
                                               draft_probs, draft_token_ids)
            self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                               bonus_token_ids,
                                               draft_token_ids)

        accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
            target_probs,
            draft_probs,
            draft_token_ids,
        )

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )
        return output_token_ids

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence is accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        # NOTE: _multinomial overwrites recovered_probs in place; that is safe
        # because the tensor was freshly created just above.
        recovered_token_ids = _multinomial(recovered_probs,
                                           num_samples=1).reshape(
                                               batch_size, k)
        return accepted, recovered_token_ids

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                            {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]

        uniform_rand = torch.rand(batch_size,
                                  k,
                                  dtype=self.probs_dtype,
                                  device=target_probs.device)

        ratio = selected_target_probs / selected_draft_probs
        # torch.minimum propagates NaN (e.g. a 0/0 ratio), and
        # `uniform_rand < NaN` is False, so such tokens are rejected. The cap
        # tensor matches ratio's dtype to avoid int/float type promotion.
        capped_ratio = torch.minimum(
            ratio, torch.ones(1, dtype=ratio.dtype, device=ratio.device))

        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
        according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO: Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # Normalise over the vocab axis. shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1, keepdim=True)

        return recovered_probs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to sample
        recovered tokens in the first rejection case.

        See _get_recovered_probs for more details

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny

    def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        recovered_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size, num_bonus_tokens]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via rejection sampling, all subsequent
        token ids are set to -1 for the sequence.

        The first rejected position is filled with the corresponding
        recovered token id, so each row holds the accepted draft prefix
        followed by at most one recovered token. The bonus column is
        currently always -1 (bonus tokens are disabled, see below).

        Also adds this batch to num_accepted_tokens, num_emitted_tokens and
        num_draft_tokens. If the counter tensors were never allocated (no
        init_gpu_tensors call), they are created on ``accepted.device``.

        shape = [batch_size, k + num_bonus_tokens]
        """
        bonus_token_ids = bonus_token_ids.squeeze()
        batch_size, k = recovered_token_ids.shape

        # Determine the index of the first False value for each row; rows
        # without any False value get k (every draft token is kept).
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor, initialised to -1.
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        # View over the first k columns; writes go through to the full tensor.
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # We disable bonus tokens because it causes corrupt KV cache for
        # proposal methods that require KV cache. We can fix it by "prefilling"
        # the bonus token in the proposer.
        output_with_bonus_tokens[:, -1] = -1

        # Fill the recovered token ids: the first rejected position holds -1,
        # so multiply it by 0 and add the recovered id there.
        output.mul_(~after_false_mask).add_(
            recovered_token_ids.mul(after_false_mask))

        # Without this, `None += tensor` raises TypeError whenever
        # init_gpu_tensors was not called (e.g. when running on CPU).
        if self.num_accepted_tokens is None or self.num_emitted_tokens is None:
            self._allocate_counters(accepted.device)

        # NOTE: `accepted` is not causally masked, so this also counts draft
        # tokens accepted after an earlier rejection in the same row.
        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        """Assert the input shapes agree with each other (see forward)."""
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape

        assert draft_batch_size == target_batch_size, (
            f"{draft_batch_size=} {target_batch_size=}")
        assert num_draft_probs == num_target_probs, (
            f"{num_draft_probs=} {num_target_probs=}")
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

        assert draft_token_ids_batch_size == draft_batch_size, (
            f"{draft_token_ids_batch_size=} {draft_batch_size=}")
        assert num_draft_token_ids == num_draft_probs, (
            f"{num_draft_token_ids=} {num_draft_probs=}")

        assert bonus_batch_size == target_batch_size, (
            f"{bonus_batch_size=} {target_batch_size=}")
        assert num_bonus_tokens == self._num_bonus_tokens, (
            f"{num_bonus_tokens=} expected {self._num_bonus_tokens}")

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        """Assert probs use probs_dtype and token ids use token_id_dtype."""
        assert all(probs.dtype == self.probs_dtype
                   for probs in [target_probs, draft_probs])
        assert all(token_ids.dtype == self.token_id_dtype
                   for token_ids in [bonus_token_ids, draft_token_ids])

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        """Assert all four inputs live on the same device."""
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
        ]
        assert all(devices[0] == device for device in devices), f"{devices=}"

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        bonus_token_ids: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        """Assert every token id lies in [0, vocab_size)."""
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)
- # torch.multinomial forces a GPU<->CPU sync.
- # Therefore, we use an optimized implementation instead that skips the sync.
- # Note that we always sample with replacement.
- # probs will be modified in place, but this is fine, as we pass
- # in a copy already.
- @torch.jit.script
- def _multinomial(
- probs: torch.Tensor,
- num_samples: int,
- ) -> torch.Tensor:
- if num_samples > 1:
- # This is equivalent to torch.repeat_interleaved (which also
- # forces a GPU<->CPU sync).
- probs = probs[:, None, :].expand(probs.shape[0], num_samples,
- probs.shape[1]).contiguous().view(
- -1, probs.shape[1])
- q = torch.empty_like(probs).exponential_(1.0)
- return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|