target_model_runner.py

from typing import List, Optional

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig,
                                     MultiModalConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig)
from aphrodite.common.sequence import SequenceGroupMetadata
from aphrodite.task_handler.model_runner import (
    ModelInputForGPUWithSamplingMetadata, ModelRunner)


class TargetModelRunner(ModelRunner):
    """Specialized model runner for the speculative decoding target model.

    In speculative decoding, the log probabilities that are ultimately kept
    may not be the ones selected by the target model's sampling step. Time
    spent computing log probabilities in the target model is therefore
    wasted, since log probabilities are only needed after deciding which
    tokens are accepted. For this reason, disabling log probabilities in
    the target model makes decoding faster. This runner sets the
    SamplingMetadata parameters according to whether log probabilities are
    requested or not.
    """

    def __init__(self,
                 model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 scheduler_config: SchedulerConfig,
                 device_config: DeviceConfig,
                 cache_config: CacheConfig,
                 load_config: LoadConfig,
                 lora_config: Optional[LoRAConfig],
                 kv_cache_dtype: Optional[str] = "auto",
                 is_driver_worker: bool = False,
                 prompt_adapter_config: Optional[PromptAdapterConfig] = None,
                 multimodal_config: Optional[MultiModalConfig] = None,
                 return_hidden_states: bool = False,
                 **kwargs):
        # Internal flag indicating whether token log probabilities are
        # needed; defaults to True (logprobs disabled).
        self.disable_logprobs = True
        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            multimodal_config=multimodal_config,
            prompt_adapter_config=prompt_adapter_config,
            return_hidden_states=return_hidden_states,
            **kwargs,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        model_input: ModelInputForGPUWithSamplingMetadata = super(
        ).prepare_model_input(seq_group_metadata_list, virtual_engine,
                              finished_requests_ids)
        # If token log probabilities are disabled, skip generating the
        # sampler CPU output; the GPU sampled_token_id tensors are
        # serialized directly as needed. If log probabilities are enabled,
        # synchronize all sampling-related tensors, including the logprobs
        # tensors.
        model_input.sampling_metadata.skip_sampler_cpu_output = (
            self.disable_logprobs)
        return model_input
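

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): a minimal illustration of how
# the `disable_logprobs` flag propagates into SamplingMetadata via
# prepare_model_input. Here `runner` stands for a fully constructed
# TargetModelRunner and `seq_groups` for scheduler output
# (List[SequenceGroupMetadata]); both are assumed, not defined here.
#
#     runner.disable_logprobs = True   # fast path: skip sampler CPU output
#     model_input = runner.prepare_model_input(seq_groups)
#     assert model_input.sampling_metadata.skip_sampler_cpu_output
#
#     runner.disable_logprobs = False  # caller wants logprobs back
#     model_input = runner.prepare_model_input(seq_groups)
#     assert not model_input.sampling_metadata.skip_sampler_cpu_output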