
chore: add support for up to 2048 block size (#715)

AlpinDale 6 months ago
parent
commit
008e646c7e

+ 1 - 1
aphrodite/engine/args_tools.py

@@ -514,7 +514,7 @@ class EngineArgs:
             "--block-size",
             type=int,
             default=EngineArgs.block_size,
-            choices=[8, 16, 32],
+            choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
             help="Category: Cache Options\n"
             "token block size",
         )

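The change that matches the commit title is this widened choices list: --block-size now accepts KV-cache token block sizes up to 2048 instead of capping at 32. A minimal sketch of the resulting CLI behavior, using plain argparse with the same choices (standalone for illustration only; the real flag is registered through EngineArgs.add_cli_args):

import argparse

# Mirror the updated --block-size argument.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--block-size",
    type=int,
    default=16,  # stands in for EngineArgs.block_size here
    choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
    help="token block size",
)

print(parser.parse_args(["--block-size", "2048"]).block_size)  # now accepted -> 2048

try:
    parser.parse_args(["--block-size", "64"])  # still rejected: 64 is not in the choices list
except SystemExit:
    print("64 is not an allowed block size")
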
+ 2 - 3
aphrodite/executor/neuron_executor.py

@@ -97,9 +97,8 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
         self,
         execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(
-            self.driver_worker.execute_model
-        )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output
 
     async def check_health_async(self) -> None:

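The async Neuron executor now forwards the whole ExecuteModelRequest to the driver worker instead of unpacking seq_group_metadata_list, matching the worker-side changes below. make_async is the engine's helper for awaiting a blocking call; a minimal sketch of how such a helper is typically written, assuming a thread-pool based implementation (the actual aphrodite utility may differ in detail):

import asyncio
import functools
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    # Wrap a blocking callable so it can be awaited: each call is handed to
    # the loop's default thread-pool executor instead of blocking the loop.
    def _async_wrapper(*args, **kwargs) -> "asyncio.Future[T]":
        loop = asyncio.get_running_loop()
        p_func = functools.partial(func, *args, **kwargs)
        return loop.run_in_executor(None, p_func)

    return _async_wrapper
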
+ 2 - 0
aphrodite/task_handler/neuron_model_runner.py

@@ -55,6 +55,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
+        **kwargs,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -197,6 +198,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None
     ) -> ModelInputForNeuron:
+        multi_modal_kwargs = None
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
         is_prompt = seq_group_metadata_list[0].is_prompt

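NeuronModelRunner gains a **kwargs catch-all so it can be constructed with the same keyword set other model runners receive (for example a cache_config it does not use), and the method returning ModelInputForNeuron pre-initializes multi_modal_kwargs to None so both the prompt and decode branches have a defined value. A toy, self-contained illustration of the constructor pattern (the class names are invented for the sketch, not aphrodite's):

class GPUStyleRunner:
    def __init__(self, model_config, device_config, cache_config):
        self.cache_config = cache_config  # this runner actually consumes it

class NeuronStyleRunner:
    def __init__(self, model_config, device_config, **kwargs):
        # Extra keywords such as cache_config are accepted and ignored,
        # so one shared construction call works for every runner.
        self.model_config = model_config
        self.device_config = device_config

shared_kwargs = dict(model_config="m", device_config="d", cache_config="c")
GPUStyleRunner(**shared_kwargs)     # works: uses all three keywords
NeuronStyleRunner(**shared_kwargs)  # also works: cache_config falls into **kwargs
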
+ 3 - 0
aphrodite/task_handler/neuron_worker.py

@@ -90,6 +90,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         return WorkerInput(num_seq_groups=len(
             execute_model_req.seq_group_metadata_list), )
 
+    def execute_worker(self, worker_input: WorkerInput) -> None:
+        pass
+
     def get_cache_block_size_bytes(self) -> int:
         """Determine the size in bytes of a cache block.