proposer_worker_base.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from abc import ABC, abstractmethod
  2. from typing import List, Optional, Set, Tuple
  3. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  4. from aphrodite.spec_decode.interfaces import SpeculativeProposer
  5. from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
  6. class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
  7. """Interface for proposer workers"""
  8. @abstractmethod
  9. def sampler_output(
  10. self,
  11. execute_model_req: ExecuteModelRequest,
  12. sample_len: int,
  13. # A set containing all sequence IDs that were assigned bonus tokens
  14. # in their last forward pass. This set is used to backfill the KV cache
  15. # with the key-value pairs of the penultimate token in the sequences.
  16. # This parameter is only used by the MultiStepWorker, which relies on
  17. # the KV cache for token generation. It is not used by workers that
  18. # do not utilize the KV cache.
  19. seq_ids_with_bonus_token_in_last_step: Set[int]
  20. ) -> Tuple[Optional[List[SamplerOutput]], bool]:
  21. raise NotImplementedError
  22. def set_include_gpu_probs_tensor(self) -> None:
  23. """Implementation optional"""
  24. pass
  25. def set_should_modify_greedy_probs_inplace(self) -> None:
  26. """Implementation optional"""
  27. pass
  28. class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
  29. """Proposer worker which does not use a model with kvcache"""
  30. def execute_model(
  31. self,
  32. execute_model_req: Optional[ExecuteModelRequest] = None
  33. ) -> List[SamplerOutput]:
  34. """get_spec_proposals is used to get the proposals"""
  35. return []
  36. def determine_num_available_blocks(self) -> Tuple[int, int]:
  37. """This is never called on the proposer, only the target model"""
  38. raise NotImplementedError
  39. def initialize_cache(self, num_gpu_blocks: int,
  40. num_cpu_blocks: int) -> None:
  41. pass
  42. def get_cache_block_size_bytes(self) -> int:
  43. return 0