|
@@ -147,9 +147,18 @@ class SchedulerOutputs:
|
|
|
and not self.blocks_to_swap_out and not self.blocks_to_copy)
|
|
|
|
|
|
def _sort_by_lora_ids(self):
|
|
|
- self.scheduled_seq_groups = sorted(
|
|
|
- self.scheduled_seq_groups,
|
|
|
- key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
|
|
|
+ assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
|
|
|
+
|
|
|
+ def key_fn(group: ScheduledSequenceGroup):
|
|
|
+ key = (group.seq_group.lora_int_id, group.seq_group.request_id)
|
|
|
+ if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
|
|
|
+ # Sort sequence groups so that all prefills come before all
|
|
|
+ # decodes as required by chunked prefill.
|
|
|
+ return (not group.seq_group.is_prefill(), *key)
|
|
|
+ return key
|
|
|
+
|
|
|
+ self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
|
|
|
+ key=key_fn)
|
|
|
|
|
|
@property
|
|
|
def lora_requests(self) -> Set[LoRARequest]:
|