
feat: support chunked prefill with LoRA (#823)

AlpinDale 3 months ago
parent commit 0a369f9171
3 changed files with 21 additions and 9 deletions
  1. aphrodite/common/config.py (+2 -1)
  2. aphrodite/processing/scheduler.py (+12 -3)
  3. aphrodite/task_handler/model_runner.py (+7 -5)

+ 2 - 1
aphrodite/common/config.py

@@ -1597,7 +1597,8 @@ class LoRAConfig:
 
     def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
         if scheduler_config.chunked_prefill_enabled:
-            raise ValueError("LoRA is not supported with chunked prefill yet.")
+            logger.warning(
+                "Chunked Prefill with LoRA is not rigorously tested.")
 
     def verify_with_parallel_config(self, parallel_config: ParallelConfig):
         if self.lora_vocab_padding_size % parallel_config.world_size != 0:
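
The hard failure is downgraded to a warning: enabling chunked prefill together with LoRA now proceeds instead of aborting at config-verification time. A minimal sketch of the new gate, using hypothetical stand-ins for the two config classes (only the method touched by this commit is reproduced):

    import logging

    logger = logging.getLogger("aphrodite")

    class SchedulerConfig:  # hypothetical stand-in for illustration
        def __init__(self, chunked_prefill_enabled: bool):
            self.chunked_prefill_enabled = chunked_prefill_enabled

    class LoRAConfig:  # hypothetical stand-in; real class has more fields
        def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
            if scheduler_config.chunked_prefill_enabled:
                # Before: raise ValueError(...) -- the combination was rejected.
                # After: warn and continue; the combination works but is
                # not rigorously tested yet.
                logger.warning(
                    "Chunked Prefill with LoRA is not rigorously tested.")

    LoRAConfig().verify_with_scheduler_config(SchedulerConfig(True))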

+ 12 - 3
aphrodite/processing/scheduler.py

@@ -147,9 +147,18 @@ class SchedulerOutputs:
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
     def _sort_by_lora_ids(self):
-        self.scheduled_seq_groups = sorted(
-            self.scheduled_seq_groups,
-            key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
+        assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
+
+        def key_fn(group: ScheduledSequenceGroup):
+            key = (group.seq_group.lora_int_id, group.seq_group.request_id)
+            if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
+                # Sort sequence groups so that all prefills come before all
+                # decodes as required by chunked prefill.
+                return (not group.seq_group.is_prefill(), *key)
+            return key
+
+        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
+                                           key=key_fn)
 
     @property
     def lora_requests(self) -> Set[LoRARequest]:
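
Chunked prefill can put prefill and decode groups in the same scheduler step, and the model runner expects all prefills ahead of all decodes. Because Python compares tuples lexicographically and False < True, prefixing the sort key with `not is_prefill()` floats prefills to the front while preserving the (lora_int_id, request_id) ordering within each class. A standalone sketch with a hypothetical minimal group type:

    from dataclasses import dataclass

    @dataclass
    class Group:  # hypothetical stand-in for ScheduledSequenceGroup
        lora_int_id: int
        request_id: str
        prefill: bool

        def is_prefill(self) -> bool:
            return self.prefill

    groups = [
        Group(2, "r1", False),  # decode
        Group(1, "r2", True),   # prefill
        Group(1, "r3", False),  # decode
        Group(2, "r4", True),   # prefill
    ]

    # (not is_prefill, lora_id, request_id): every prefill sorts ahead of
    # every decode, then groups cluster by LoRA id within each class.
    groups.sort(key=lambda g: (not g.is_prefill(), g.lora_int_id, g.request_id))
    print([(g.request_id, g.prefill) for g in groups])
    # [('r2', True), ('r4', True), ('r3', False), ('r1', False)]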

+ 7 - 5
aphrodite/task_handler/model_runner.py

@@ -569,11 +569,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             inter_data.lora_requests.add(seq_group_metadata.lora_request)
         query_len = inter_data.query_lens[seq_idx]
         inter_data.lora_index_mapping.append([lora_id] * query_len)
-        inter_data.lora_prompt_mapping.append(
-            [lora_id] *
-            (query_len if seq_group_metadata.sampling_params
-             and seq_group_metadata.sampling_params.prompt_logprobs is not None
-             else 1))
+        sampling_params = seq_group_metadata.sampling_params
+        if sampling_params and sampling_params.prompt_logprobs is not None:
+            inter_data.lora_prompt_mapping.append([lora_id] * query_len)
+        elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
+            inter_data.lora_prompt_mapping.append([lora_id])
+        else:
+            inter_data.lora_prompt_mapping.append([])
 
     def _compute_prompt_adapter_input(
             self, inter_data: InterDataForSeqGroup,
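
lora_prompt_mapping holds one LoRA id per token position that will reach the logits computation. With prompt_logprobs requested, every query token needs an entry; otherwise only the final token of a sampling sequence does; and under chunked prefill, an intermediate prefill chunk (do_sample is False because the prompt is not finished) must contribute no entries at all. A hedged distillation of the new branch, with hypothetical parameter names:

    def lora_prompt_mapping_entries(lora_id: int, query_len: int,
                                    wants_prompt_logprobs: bool,
                                    chunked_prefill_enabled: bool,
                                    do_sample: bool) -> list:
        """Hypothetical sketch of the branch added in model_runner.py."""
        if wants_prompt_logprobs:
            # Prompt logprobs need logits for every query token.
            return [lora_id] * query_len
        if not chunked_prefill_enabled or do_sample:
            # Only the final position feeds the sampler.
            return [lora_id]
        # A mid-prompt chunk samples nothing, so it must not claim
        # a slot in the prompt mapping.
        return []

    # A mid-prompt chunk under chunked prefill contributes no entries:
    assert lora_prompt_mapping_entries(3, 128, False, True, False) == []
    # The final chunk (do_sample=True) contributes exactly one:
    assert lora_prompt_mapping_entries(3, 64, False, True, True) == [3]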