ngram_worker.py

import weakref
from typing import List, Optional, Set, Tuple

import torch

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from aphrodite.spec_decode.top1_proposer import Top1Proposer


class NGramWorker(NonLLMProposerWorkerBase):
    """NGramWorker provides a lightweight drafter that does not need a model.

    Currently, NGramWorker only implements prompt lookup decoding; in the
    future it may also cover RAG-style drafters and other scenarios that do
    not rely on an LLM to produce proposals.
    """

    def __init__(self, *args, **kwargs):
        # Get local_rank/vocab_size from the kwargs.
        self.local_rank = kwargs["local_rank"]
        self.vocab_size = kwargs["model_config"].get_vocab_size()

        # Lazily initialized by init_device().
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        # Candidate n-gram sizes are searched between
        # ngram_prompt_lookup_min and ngram_prompt_lookup_max (inclusive).
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.load_model = lambda *args, **kwargs: None

        # Current NGramWorker only supports Top1Proposer.
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter. NGramWorker does not use the KV cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
        """Use n-gram matching to pick proposal candidates. Returns a list of
        sampler outputs, one per SequenceGroupMetadata.

        The ngram worker already performs the needed transpose internally, so
        the indicator passed to sampler_output_to_torch should be False.
        """
        self._raise_if_unsupported(execute_model_req)

        has_spec_out = False
        token_id_list: List[Optional[torch.Tensor]] = []
        token_prob_list: List[Optional[torch.Tensor]] = []
        for idx, seq_group_metadata in enumerate(
                execute_model_req.seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=self.device)
            input_length = seq_data.get_len()
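
            # Search candidate n-gram sizes from largest to smallest; the
            # first (longest) suffix that also occurs earlier in the context
            # wins.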
            for ngram_size in range(
                    min(self.ngram_prompt_lookup_max, input_length - 1),
                    self.ngram_prompt_lookup_min - 1,
                    -1,
            ):
                ngram_tensor = input_ids[-ngram_size:]
                if ngram_size == 1:
                    # Do not match the trailing token against itself; no need
                    # for unfold/all in the single-token case.
                    matches = (input_ids[:-1] == ngram_tensor)
                else:
                    windows = input_ids.unfold(dimension=0,
                                               size=ngram_size,
                                               step=1)
                    # Exclude the last window so the suffix does not match
                    # itself.
                    matches = (windows[:-1] == ngram_tensor).all(dim=-1)
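
                # Toy example: with input_ids = [7, 1, 2, 9, 1, 2] and
                # ngram_size = 2, the suffix n-gram is [1, 2]; the windows
                # (excluding the suffix itself) are [7,1], [1,2], [2,9],
                # [9,1], so matches = [False, True, False, False] and the
                # first match is at index 1.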
                # first_match contains "values" (a bool indicating whether a
                # match was found) and "indices" (the index of the first
                # match).
                # Note that "first_match.values.item()" triggers a GPU-CPU
                # sync, so it is a bit inefficient, but we have not found a
                # better way to do this.
                first_match = matches.max(dim=-1)
                if first_match.values.item():
                    proposal_start_idx = first_match.indices.add_(ngram_size)
                    spec_indices = (
                        proposal_start_idx).repeat(sample_len) + torch.arange(
                            sample_len, device=self.device)
                    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
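                    # The clamp above means that when fewer than sample_len
                    # tokens follow the matched position, the proposal
                    # repeats the final context token instead of indexing
                    # out of range.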
                    res = input_ids.gather(dim=-1, index=spec_indices)
                    token_id_list.append(res)
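                    # The draft is deterministic, so its "probabilities" are
                    # one-hot distributions over the proposed token ids.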
                    token_prob_list.append(
                        torch.nn.functional.one_hot(
                            res,
                            num_classes=self.vocab_size).to(torch.float32))
                    has_spec_out = True
                    break
            else:
                token_id_list.append(None)
                token_prob_list.append(None)

        if not has_spec_out:
            return None, False
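
        # Package per-sequence proposals into SamplerOutput objects; entries
        # stay None for sequences where no n-gram match was found.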
        outputs: List[Optional[SamplerOutput]] = []
        for idx in range(len(execute_model_req.seq_group_metadata_list)):
            if token_id_list[idx] is None:
                outputs.append(None)
            else:
                outputs.append(
                    SamplerOutput(
                        outputs=None,
                        sampled_token_probs=token_prob_list[idx],
                        logprobs=torch.zeros((sample_len, self.vocab_size),
                                             dtype=torch.float32,
                                             device=self.device),
                        sampled_token_ids=token_id_list[idx],
                    ))

        return outputs, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        # Unused parameter. NGramWorker does not use the KV cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")