proposer_worker_base.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from abc import ABC, abstractmethod
  2. from typing import List, Optional, Set, Tuple
  3. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  4. from aphrodite.lora.request import LoRARequest
  5. from aphrodite.spec_decode.interfaces import SpeculativeProposer
  6. from aphrodite.task_handler.worker_base import WorkerBase
class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
    """Interface for proposer workers in speculative decoding.

    Combines the generic ``WorkerBase`` contract with the
    ``SpeculativeProposer`` interface. Concrete subclasses supply the
    draft-token sampling step via :meth:`sampler_output`.
    """

    @abstractmethod
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """Produce up to ``sample_len`` steps of sampler output for the
        given request.

        Returns a tuple of (sampler outputs or ``None``, bool flag).
        # NOTE(review): the meaning of the bool flag is defined by
        # concrete subclasses — confirm against implementations.
        """
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self) -> None:
        """Implementation optional"""
        # Default is a no-op; subclasses that track GPU probability
        # tensors may override.
        pass

    # LoRA is unsupported on proposer workers by default; each method
    # raises rather than silently ignoring the request.
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
  25. class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
  26. """Proposer worker which does not use a model with kvcache"""
  27. def execute_model(
  28. self,
  29. execute_model_req: Optional[ExecuteModelRequest] = None
  30. ) -> List[SamplerOutput]:
  31. """get_spec_proposals is used to get the proposals"""
  32. return []
  33. def determine_num_available_blocks(self) -> Tuple[int, int]:
  34. """This is never called on the proposer, only the target model"""
  35. raise NotImplementedError
  36. def initialize_cache(self, num_gpu_blocks: int,
  37. num_cpu_blocks: int) -> None:
  38. pass
  39. def get_cache_block_size_bytes(self) -> int:
  40. return 0
  41. def add_lora(self, lora_request: LoRARequest) -> bool:
  42. raise ValueError(f"{type(self)} does not support LoRA")
  43. def remove_lora(self, lora_id: int) -> bool:
  44. raise ValueError(f"{type(self)} does not support LoRA")
  45. def list_loras(self) -> Set[int]:
  46. raise ValueError(f"{type(self)} does not support LoRA")