ngram_worker.py

from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.top1_proposer import Top1Proposer
from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase

class NGramWorker(LoraNotSupportedWorkerBase):
    """NGramWorker provides a light drafter without the need for a model.

    The current NGramWorker only implements prompt lookup decoding; in the
    future it may also cover RAG-style drafting and other scenarios that
    don't rely on an LLM to give proposals.
    """
    def __init__(self, *args, **kwargs):
        # Get local_rank/vocab_size from the worker kwargs.
        self.local_rank = kwargs["local_rank"]
        self.vocab_size = kwargs["model_config"].get_vocab_size()

        # Lazily initialized by init_device().
        self._proposer: Top1Proposer
    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        # Search valid candidate window between
        # ngram_prompt_lookup_min/ngram_prompt_lookup_max.
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
    def init_device(self):
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.load_model = lambda *args, **kwargs: None

        # Currently only Top1Proposer is supported.
        self._proposer = Top1Proposer(
            self,
            device=self.device,
            vocab_size=self.vocab_size,
        )
    def set_include_gpu_probs_tensor(self):
        # NGram doesn't need a GPU sampler, so this is a no-op.
        pass
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
    ) -> None:
        """NGram doesn't depend on model execution, so this is a no-op."""
        pass
    def determine_num_available_blocks(self) -> None:
        """NGram doesn't depend on model execution, so there is no need to
        check blocks."""
        pass
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """There is no cache to handle, so this is a no-op."""
        pass
    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes."""
        return 0
    def sampler_output(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """Use the n-gram match algorithm to pick proposal candidates.
        Returns the list of sampler outputs, one per SequenceGroupMetadata.
        """
        self._raise_if_unsupported(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )
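
        # Prompt lookup matching: for each sequence, take its trailing
        # n-gram (largest window first) and look for an earlier occurrence
        # of that n-gram in the token history. If one exists, the tokens
        # that followed the earlier occurrence become the draft proposal.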
        arr = []
        has_spec_out = False
        for seq_group_metadata in seq_group_metadata_list:
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=self.device)
            input_length = seq_data.get_len()

            for ngram_size in range(
                    min(self.ngram_prompt_lookup_max, input_length - 1),
                    self.ngram_prompt_lookup_min,
                    -1,
            ):
                ngram_tensor = input_ids[-1 * ngram_size:]
                # Slide an ngram_size-wide window over the sequence and mark
                # every position equal to the trailing n-gram. The last
                # window always matches itself, so a real candidate requires
                # more than one match.
                windows = input_ids.unfold(dimension=0,
                                           size=ngram_size,
                                           step=1)
                matches = (windows == ngram_tensor).all(dim=1)
                match_indices = matches.nonzero(as_tuple=True)[0]
                if match_indices.size()[0] > 1:
                    has_spec_out = True
                    res = seq_data.get_token_ids()
                    res = res[match_indices[0] + ngram_size:match_indices[0] +
                              ngram_size + sample_len]
                    res_len = len(res)
                    # Pad with 0 so the output always has sample_len tokens.
                    res += [0] * (sample_len - res_len)
                    break
            else:
                # No candidate found for any window size; fill with 0.
                res = [0] * sample_len
            arr.append(res)
        if not has_spec_out:
            return None, False

        outputs = []
        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
        indices = token_ids.unsqueeze(2)

        token_probs = torch.zeros(
            (len(seq_group_metadata_list), sample_len, self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        token_probs.scatter_(2, indices, 1)
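        # At this point token_probs holds a one-hot distribution over the
        # vocabulary for every draft position (1.0 at the proposed token id).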
        for i in range(len(seq_group_metadata_list)):
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_probs[i],
                    sampled_token_ids=token_ids[i],
                ))

        return outputs, False
    def get_spec_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            max_proposal_len,
        )
    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")