
chore: decouple `should_modify_greedy_probs_inplace` (#671)

AlpinDale 6 months ago
parent commit 1394008421

+ 4 - 0
aphrodite/lora/layers.py

@@ -1064,6 +1064,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     @property
     def include_gpu_probs_tensor(self):
         return self.base_layer.include_gpu_probs_tensor
+    
+    @property
+    def should_modify_greedy_probs_inplace(self):
+        return self.base_layer.should_modify_greedy_probs_inplace
 
     def create_lora_weights(
         self,

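The first hunk makes the LoRA logits-processor wrapper forward the new flag to the layer it wraps, the same way it already forwards `include_gpu_probs_tensor`. A minimal sketch of that read-only delegation pattern, with stand-in class names rather than Aphrodite's real ones:

```python
class StubSampler:
    def __init__(self) -> None:
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False


class SamplerWrapper:
    """Forwards read-only sampler flags to the wrapped base layer."""

    def __init__(self, base_layer: StubSampler) -> None:
        self.base_layer = base_layer

    @property
    def should_modify_greedy_probs_inplace(self) -> bool:
        # Without this property, code reading the flag off the wrapper
        # would miss the value configured on the underlying sampler.
        return self.base_layer.should_modify_greedy_probs_inplace


base = StubSampler()
wrapped = SamplerWrapper(base)
base.should_modify_greedy_probs_inplace = True
assert wrapped.should_modify_greedy_probs_inplace is True
```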
+ 2 - 2
aphrodite/modeling/layers/sampler.py

@@ -51,6 +51,7 @@ class Sampler(nn.Module):
         # containing the sampled token ids and probabilities. This is used by
         # speculative decoding.
         self.include_gpu_probs_tensor = False
+        self.should_modify_greedy_probs_inplace = False
 
     def _init_sampling_tensors(
         self,
@@ -222,8 +223,7 @@ class Sampler(nn.Module):
         This is used by speculative decoding, which requires that the sampling
         method be encoded into the probability distribution.
         """
-        # Modify greedy probs if include_gpu_probs_tensor is set.
-        return self.include_gpu_probs_tensor
+        return self.should_modify_greedy_probs_inplace
 
 
 def _get_bin_counts_and_mask(

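This is the heart of the change: `_should_modify_greedy_probs_inplace` previously returned `include_gpu_probs_tensor`, so requesting GPU probability tensors also forced in-place modification of greedy probabilities. After this commit the two behaviors are controlled by separate flags. A toy illustration of the decoupled semantics (`ToySampler` is a stand-in, not the real `Sampler` module):

```python
class ToySampler:
    def __init__(self) -> None:
        # Both flags now default to False and are set independently.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        # Post-commit: an explicit flag, no longer inferred from
        # include_gpu_probs_tensor.
        return self.should_modify_greedy_probs_inplace


sampler = ToySampler()
sampler.include_gpu_probs_tensor = True  # ask for GPU probs only
# In-place greedy modification is no longer a side effect:
assert sampler._should_modify_greedy_probs_inplace is False
```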
+ 3 - 0
aphrodite/spec_decode/medusa_worker.py

@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
     def set_include_gpu_probs_tensor(self):
         pass
 
+    def set_should_modify_greedy_probs_inplace(self):
+        pass
+
     @torch.inference_mode()
     def sampler_output(
         self,

+ 4 - 0
aphrodite/spec_decode/multi_step_worker.py

@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True)
+
     @torch.inference_mode()
     def sampler_output(
         self,

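`MultiStepWorker` opts in because, per the sampler docstring above, speculative decoding requires the sampling method to be encoded into the returned probability distribution. Conceptually, "modifying greedy probs in place" means rewriting a greedy-sampled row as a one-hot distribution over its argmax token; the sketch below illustrates that idea only and is not Aphrodite's actual `_modify_greedy_probs` implementation:

```python
import torch


def modify_greedy_probs_inplace(probs: torch.Tensor,
                                greedy_rows: torch.Tensor) -> None:
    """Rewrite greedy-sampled rows as one-hot distributions, in place."""
    greedy_tokens = probs[greedy_rows].argmax(dim=-1)
    probs[greedy_rows] = 0.0  # zero out the selected rows
    probs[greedy_rows, greedy_tokens] = 1.0  # argmax token gets prob 1


probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.5, 0.3, 0.2]])
modify_greedy_probs_inplace(probs, torch.tensor([0]))
# Row 0 becomes [1., 0., 0.]; row 1 is left untouched.
```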
+ 4 - 0
aphrodite/spec_decode/proposer_worker_base.py

@@ -29,6 +29,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
         """Implementation optional"""
         pass
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        """Implementation optional"""
+        pass
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise ValueError(f"{type(self)} does not support LoRA")
 

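The base class adds the hook as an optional no-op, so proposers with nothing to configure, such as `MedusaWorker` above, inherit a do-nothing default while `MultiStepWorker` overrides it, and callers can invoke the hook uniformly. A sketch of that pattern using hypothetical subclass names:

```python
class ProposerBase:
    def set_should_modify_greedy_probs_inplace(self) -> None:
        """Implementation optional."""
        pass


class SamplerlessProposer(ProposerBase):
    """Inherits the no-op, like MedusaWorker."""


class SamplingProposer(ProposerBase):
    """Overrides the hook, like MultiStepWorker."""

    def __init__(self) -> None:
        self.flag = False

    def set_should_modify_greedy_probs_inplace(self) -> None:
        self.flag = True


# Callers need no isinstance checks; every proposer accepts the call.
for proposer in (SamplerlessProposer(), SamplingProposer()):
    proposer.set_should_modify_greedy_probs_inplace()
```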
+ 6 - 0
aphrodite/spec_decode/smaller_tp_proposer_worker.py

@@ -79,6 +79,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
         # Need include_gpu_probs_tensor for multi_step_worker
         self._worker.set_include_gpu_probs_tensor()
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        if self._is_dummy:
+            return
+
+        self._worker.set_should_modify_greedy_probs_inplace()
+
     def load_model(self) -> None:
         if self._is_dummy:
             return

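`SmallerTpProposerWorker` runs the draft model on a subset of tensor-parallel ranks, so ranks holding a dummy worker return early rather than forwarding to a model they do not have; this mirrors the guard used by `load_model` and the other setters in the same class. A minimal sketch of the guard with stand-in names:

```python
class StubInnerWorker:
    def __init__(self) -> None:
        self.flag = False

    def set_should_modify_greedy_probs_inplace(self) -> None:
        self.flag = True


class GuardedProposer:
    def __init__(self, is_dummy: bool, worker=None) -> None:
        self._is_dummy = is_dummy
        self._worker = worker  # None on ranks outside the draft TP group

    def set_should_modify_greedy_probs_inplace(self) -> None:
        if self._is_dummy:
            return  # nothing to configure on this rank
        self._worker.set_should_modify_greedy_probs_inplace()


dummy = GuardedProposer(is_dummy=True)
dummy.set_should_modify_greedy_probs_inplace()  # safe no-op
real = GuardedProposer(is_dummy=False, worker=StubInnerWorker())
real.set_should_modify_greedy_probs_inplace()
assert real._worker.flag is True
```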
+ 3 - 0
aphrodite/spec_decode/spec_decode_worker.py

@@ -294,7 +294,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
+        (self.scorer_worker.model_runner.model.sampler.
+         should_modify_greedy_probs_inplace) = True
         self.proposer_worker.set_include_gpu_probs_tensor()
+        self.proposer_worker.set_should_modify_greedy_probs_inplace()
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.