proposer_worker_base.py

from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.lora.request import LoRARequest
from aphrodite.spec_decode.interfaces import SpeculativeProposer
from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase


class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
    """Interface for proposer workers"""

    @abstractmethod
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # A set containing all sequence IDs that were assigned bonus tokens
        # in their last forward pass. This set is used to backfill the KV
        # cache with the key-value pairs of the penultimate token in the
        # sequences. This parameter is only used by the MultiStepWorker,
        # which relies on the KV cache for token generation. It is not used
        # by workers that do not utilize the KV cache.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self) -> None:
        """Implementation optional"""
        pass

    def set_should_modify_greedy_probs_inplace(self) -> None:
        """Implementation optional"""
        pass

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
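
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal concrete
# proposer showing the expected shape of a sampler_output() implementation.
# The `_draft_step` helper is hypothetical; it stands in for one forward
# pass of a draft model. The second element of the returned tuple is a flag
# interpreted by the calling proposer; False is used here only as a
# conservative placeholder.


class _ExampleProposerWorker(ProposerWorkerBase):

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        # Produce one SamplerOutput per speculative step.
        outputs = [
            self._draft_step(execute_model_req) for _ in range(sample_len)
        ]
        return outputs, False

    def _draft_step(self,
                    execute_model_req: ExecuteModelRequest) -> SamplerOutput:
        # Hypothetical: run the draft model once and sample a token for
        # every sequence in the batch.
        raise NotImplementedError
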
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
    """Proposer worker which does not use a model with kvcache"""

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """get_spec_proposals is used to get the proposals"""
        return []

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """This is never called on the proposer, only the target model"""
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        pass

    def get_cache_block_size_bytes(self) -> int:
        return 0

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
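
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a proposer that needs
# neither a model forward pass nor a KV cache, which is the case
# NonLLMProposerWorkerBase exists for. The prompt-lookup idea and the
# get_spec_proposals() signature shown here are assumptions, not taken from
# this file. Only get_spec_proposals() has to do real work, since the base
# class already stubs out execute_model() and the cache plumbing
# (initialize_cache() is a no-op and get_cache_block_size_bytes() returns 0).


class _ExamplePromptLookupProposer(NonLLMProposerWorkerBase):

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ):
        # Hypothetical: for each sequence, match its most recent tokens
        # against earlier positions in the context and propose the tokens
        # that followed the match. No model is run, so there is no KV cache
        # to maintain.
        raise NotImplementedError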