1
0

spec_decode_base_sampler.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. from abc import abstractmethod
  2. from typing import Dict, Optional, Union
  3. import torch
  4. import torch.jit
  5. import torch.nn as nn
  6. class SpecDecodeBaseSampler(nn.Module):
  7. """Base class for samplers used for Speculative Decoding verification
  8. step.
  9. """
  10. def __init__(self,
  11. disable_bonus_tokens: bool = True,
  12. strict_mode: bool = False):
  13. """Base class constructor.
  14. Args:
  15. disable_bonus_tokens: Whether or not to disable the bonus token.
  16. Require when bonus tokens will cause corrupt KV cache for
  17. proposal methods that require KV cache.
  18. strict_mode: Whether or not to perform shape/device/dtype checks
  19. during sampling. This catches correctness issues but adds
  20. nontrivial latency.
  21. """
  22. super().__init__()
  23. self._disable_bonus_tokens = disable_bonus_tokens
  24. self._strict_mode = strict_mode
  25. # NOTE: A "bonus token" is accepted iff all proposal tokens are
  26. # accepted. There is always only one possible bonus token. We store this
  27. # value in a variable for readability.
  28. self._num_bonus_tokens = 1
  29. self.num_accepted_tokens: Optional[torch.Tensor] = None
  30. self.num_emitted_tokens: Optional[torch.Tensor] = None
  31. self.num_draft_tokens: int = 0
  32. def init_gpu_tensors(self, device: Union[int, str]) -> None:
  33. assert self.num_accepted_tokens is None
  34. if isinstance(device, int):
  35. device = f"cuda:{device}"
  36. elif not isinstance(device, str):
  37. raise ValueError(f"Device must be int or str, get {type(device)}")
  38. self.num_accepted_tokens = torch.tensor(0,
  39. dtype=torch.long,
  40. device=device)
  41. self.num_emitted_tokens = torch.tensor(0,
  42. dtype=torch.long,
  43. device=device)
  44. @property
  45. def probs_dtype(self):
  46. return torch.float32
  47. @property
  48. def token_id_dtype(self):
  49. return torch.int64
  50. def _create_output(
  51. self,
  52. accepted: torch.Tensor, # [batch_size, k]
  53. substitute_token_ids: torch.Tensor, # [batch_size, k]
  54. draft_token_ids: torch.Tensor, # [batch_size, k]
  55. bonus_token_ids: torch.Tensor, # [batch_size]
  56. ) -> torch.Tensor:
  57. """Format output. Returns a matrix of token ids. When
  58. a token is rejected via sampling, all subsequent token ids are
  59. set to -1 for the sequence.
  60. Args:
  61. accepted: A boolean tensor indicating if the corresponding
  62. draft token in draft_token_ids should be accepted or not.
  63. substitute_token_ids: A tensor of token_ids that can be used
  64. as substitutes for the draft token ids if the proposed token
  65. is rejected.
  66. draft_token_ids: A tensor of token ids speculated by the
  67. draft model.
  68. bonus_token_ids: Token ids to use as the bonus token if
  69. all the draft tokens are accepted.
  70. Returns:
  71. A tensor containing the accepted token ids. The shape of the
  72. tensor is [batch_size, k + num_bonus_tokens]
  73. """
  74. batch_size, k = substitute_token_ids.shape
  75. bonus_token_ids = bonus_token_ids.squeeze()
  76. # Determine the index of the first False value for each row.
  77. limits = (accepted == 0).max(1).indices
  78. limits[~(accepted == 0).any(1)] = k
  79. # Create masks using the indices.
  80. indices = torch.arange(k, device=accepted.device).unsqueeze(0)
  81. accepted_mask = indices < limits.unsqueeze(1)
  82. after_false_mask = indices == limits.unsqueeze(1)
  83. # Create an extended output tensor
  84. output_with_bonus_tokens = -torch.ones(
  85. (batch_size, k + self._num_bonus_tokens),
  86. dtype=self.token_id_dtype,
  87. device=accepted.device)
  88. output = output_with_bonus_tokens[:, :k]
  89. # Fill in the first k columns of the output tensor using masks and data
  90. # tensors.
  91. output[:, :k] = torch.where(accepted_mask, draft_token_ids,
  92. -torch.ones_like(draft_token_ids))
  93. # Fill the last column.
  94. # We check output directly as accepted may have True values inconsistent
  95. # with causal acceptance.
  96. output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
  97. bonus_token_ids, -1)
  98. # We disable bonus tokens because it causes corrupt KV cache for
  99. # proposal methods that require KV cache. We can fix it by "prefilling"
  100. # the bonus token in the proposer.
  101. if self._disable_bonus_tokens:
  102. output_with_bonus_tokens[:, -1] = -1
  103. # Fill the recovered token ids.
  104. output.mul_(~after_false_mask).add_(
  105. substitute_token_ids.mul(after_false_mask))
  106. self.num_accepted_tokens += accepted.sum()
  107. self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
  108. self.num_draft_tokens += batch_size * k
  109. return output_with_bonus_tokens
  110. def _raise_if_incorrect_input(
  111. self,
  112. target_probs: torch.Tensor,
  113. draft_token_ids: torch.Tensor,
  114. bonus_token_ids: torch.Tensor,
  115. draft_probs: Optional[torch.Tensor] = None,
  116. ) -> None:
  117. self._raise_if_incorrect_shape(target_probs, draft_token_ids,
  118. bonus_token_ids, draft_probs)
  119. self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
  120. bonus_token_ids, draft_probs)
  121. self._raise_if_inconsistent_device(target_probs, draft_token_ids,
  122. bonus_token_ids, draft_probs)
  123. self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
  124. draft_token_ids, bonus_token_ids)
  125. def _raise_if_incorrect_shape(
  126. self,
  127. target_probs: torch.Tensor,
  128. draft_token_ids: torch.Tensor,
  129. bonus_token_ids: torch.Tensor,
  130. draft_probs: Optional[torch.Tensor] = None,
  131. ) -> None:
  132. (target_batch_size, num_target_probs,
  133. target_vocab_size) = target_probs.shape
  134. # validate the shape of draft token ids.
  135. draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
  136. assert draft_token_ids_batch_size == target_batch_size
  137. assert num_draft_token_ids == num_target_probs
  138. # validate the shape of bonus token ids
  139. bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
  140. assert bonus_batch_size == target_batch_size
  141. assert num_bonus_tokens == self._num_bonus_tokens
  142. # validate the shape of draft probs if it is set
  143. if draft_probs is not None:
  144. (draft_batch_size, num_draft_probs,
  145. draft_vocab_size) = draft_probs.shape
  146. assert draft_batch_size == target_batch_size
  147. assert num_draft_probs == num_target_probs
  148. assert (draft_vocab_size == target_vocab_size
  149. ), f"{draft_vocab_size=} {target_vocab_size=}"
  150. def _raise_if_incorrect_dtype(
  151. self,
  152. target_probs: torch.Tensor,
  153. draft_token_ids: torch.Tensor,
  154. bonus_token_ids: torch.Tensor,
  155. draft_probs: Optional[torch.Tensor] = None,
  156. ) -> None:
  157. assert target_probs.dtype == self.probs_dtype
  158. assert draft_token_ids.dtype == self.token_id_dtype
  159. assert bonus_token_ids.dtype == self.token_id_dtype
  160. if draft_probs is not None:
  161. assert draft_probs.dtype == self.probs_dtype
  162. def _raise_if_inconsistent_device(
  163. self,
  164. target_probs: torch.Tensor,
  165. draft_token_ids: torch.Tensor,
  166. bonus_token_ids: torch.Tensor,
  167. draft_probs: Optional[torch.Tensor] = None,
  168. ) -> None:
  169. devices = [
  170. t.device for t in
  171. [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
  172. if t is not None
  173. ]
  174. assert all([devices[0] == device for device in devices])
  175. def _raise_if_out_of_bounds_vocab(
  176. self,
  177. vocab_size: int,
  178. draft_token_ids: torch.Tensor,
  179. bonus_token_ids: torch.Tensor,
  180. ) -> None:
  181. assert torch.all(bonus_token_ids < vocab_size)
  182. assert torch.all(bonus_token_ids >= 0)
  183. assert torch.all(draft_token_ids < vocab_size)
  184. assert torch.all(draft_token_ids >= 0)
  185. class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
  186. """Base class for samplers used for Speculative Decoding verification
  187. step which are deterministic.
  188. """
  189. @abstractmethod
  190. def forward(
  191. self,
  192. target_probs: torch.Tensor,
  193. bonus_token_ids: torch.Tensor,
  194. draft_probs: torch.Tensor,
  195. draft_token_ids: torch.Tensor,
  196. ) -> torch.Tensor:
  197. raise NotImplementedError
  198. class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
  199. """Base class for samplers used for Speculative Decoding verification
  200. step which are stochastic
  201. """
  202. @abstractmethod
  203. def forward(
  204. self,
  205. target_probs: torch.Tensor,
  206. bonus_token_ids: torch.Tensor,
  207. draft_probs: torch.Tensor,
  208. draft_token_ids: torch.Tensor,
  209. seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
  210. ) -> torch.Tensor:
  211. raise NotImplementedError