# spec_decode_base_sampler.py

from abc import abstractmethod
from typing import Dict, Optional

import torch
import torch.jit
import torch.nn as nn


class SpecDecodeBaseSampler(nn.Module):
    """Base class for samplers used in the Speculative Decoding verification
    step.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Base class constructor.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens would corrupt the KV cache of
                proposal methods that rely on one.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._disable_bonus_tokens = disable_bonus_tokens
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store
        # this value in a variable for readability.
        self._num_bonus_tokens = 1

        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: int) -> None:
        assert self.num_accepted_tokens is None
        device = f"cuda:{rank}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    @property
    def probs_dtype(self):
        return torch.float32

    @property
    def token_id_dtype(self):
        return torch.int64

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            substitute_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via sampling, all subsequent token ids are
        set to -1 for the sequence.

        Args:
            accepted: A boolean tensor indicating whether the corresponding
                draft token in draft_token_ids should be accepted.
            substitute_token_ids: A tensor of token ids to use as
                substitutes for draft token ids that are rejected.
            draft_token_ids: A tensor of token ids speculated by the
                draft model.
            bonus_token_ids: Token ids to use as the bonus token if
                all the draft tokens are accepted.

        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens].
        """
        batch_size, k = substitute_token_ids.shape
        bonus_token_ids = bonus_token_ids.squeeze()

        # Determine the index of the first False value for each row, i.e.
        # the first rejected position. max() over a boolean tensor returns
        # the index of the first True value.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor, initialized to -1.
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and
        # data tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly, as accepted may have True values that
        # are inconsistent with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # We disable bonus tokens because they can corrupt the KV cache of
        # proposal methods that require a KV cache. This could be fixed by
        # "prefilling" the bonus token in the proposer.
        if self._disable_bonus_tokens:
            output_with_bonus_tokens[:, -1] = -1

        # Fill the recovered (substitute) token ids at the first rejected
        # position.
        output.mul_(~after_false_mask).add_(
            substitute_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens
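
    # A worked example of the masking above (illustrative values, not from
    # the original source): with k = 3 and accepted = [[True, True, False]],
    # limits = [2], accepted_mask = [[True, True, False]], and
    # after_false_mask = [[False, False, True]]. The output row becomes
    # [draft_0, draft_1, substitute_2, -1]: the first two draft tokens are
    # kept, the rejected position gets its substitute token, and no bonus
    # token is emitted because not every draft token was accepted.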

    def _raise_if_incorrect_input(
            self,
            target_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        self._raise_if_incorrect_shape(target_probs, draft_token_ids,
                                       bonus_token_ids, draft_probs)
        self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
                                       bonus_token_ids, draft_probs)
        self._raise_if_inconsistent_device(target_probs, draft_token_ids,
                                           bonus_token_ids, draft_probs)
        self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
            self,
            target_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape

        # Validate the shape of the draft token ids.
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
        assert draft_token_ids_batch_size == target_batch_size
        assert num_draft_token_ids == num_target_probs

        # Validate the shape of the bonus token ids.
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

        # Validate the shape of the draft probs, if set.
        if draft_probs is not None:
            (draft_batch_size, num_draft_probs,
             draft_vocab_size) = draft_probs.shape
            assert draft_batch_size == target_batch_size
            assert num_draft_probs == num_target_probs
            assert (draft_vocab_size == target_vocab_size
                    ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
            self,
            target_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        assert target_probs.dtype == self.probs_dtype
        assert draft_token_ids.dtype == self.token_id_dtype
        assert bonus_token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
            self,
            target_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
            if t is not None
        ]
        assert all(device == devices[0] for device in devices)

    def _raise_if_out_of_bounds_vocab(
            self,
            vocab_size: int,
            draft_token_ids: torch.Tensor,
            bonus_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)
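

# A minimal sketch (not part of the original file) of how the counters
# maintained by _create_output could be turned into a metric. The helper
# name get_acceptance_rate is hypothetical.
def get_acceptance_rate(sampler: SpecDecodeBaseSampler) -> float:
    """Return the fraction of draft tokens accepted so far."""
    if sampler.num_draft_tokens == 0:
        return 0.0
    assert sampler.num_accepted_tokens is not None
    return sampler.num_accepted_tokens.item() / sampler.num_draft_tokens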


class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
    """Base class for deterministic samplers used in the Speculative
    Decoding verification step.
    """

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError


class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
    """Base class for stochastic samplers used in the Speculative Decoding
    verification step.
    """

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
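

# A minimal, hypothetical sketch of a concrete subclass (not part of the
# original file), assuming a trivial acceptance rule: a draft token is
# accepted iff it matches the target model's argmax. It only illustrates
# how forward() is expected to combine the strict-mode checks with
# _create_output; the real verification samplers implement their own
# acceptance logic.
class _ArgmaxVerificationSampler(SpecDecodeDeterministicBaseSampler):

    def forward(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        bonus_token_ids: torch.Tensor,  # [batch_size, 1]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> torch.Tensor:
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, draft_token_ids,
                                           bonus_token_ids, draft_probs)
        # Accept a draft token iff it equals the target distribution's
        # argmax; use that argmax as the substitute at rejected positions.
        target_argmax = target_probs.argmax(dim=-1)
        accepted = draft_token_ids == target_argmax
        return self._create_output(accepted, target_argmax, draft_token_ids,
                                   bonus_token_ids)


# Example usage (CPU, for illustration; the real code initializes the
# counters on GPU via init_gpu_tensors):
#     sampler = _ArgmaxVerificationSampler()
#     sampler.num_accepted_tokens = torch.tensor(0)
#     sampler.num_emitted_tokens = torch.tensor(0)
#     out = sampler(target_probs, bonus_token_ids, draft_probs,
#                   draft_token_ids)  # [batch_size, k + 1]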