Sfoglia il codice sorgente

chore: optimize scheduler and remove policy

AlpinDale 6 mesi fa
parent
commit
9866af1626
2 ha cambiato i file con 50 aggiunte e 124 eliminazioni
  1. 0 45
      aphrodite/processing/policy.py
  2. 50 79
      aphrodite/processing/scheduler.py

+ 0 - 45
aphrodite/processing/policy.py

@@ -1,45 +0,0 @@
-from collections import deque
-from typing import Deque
-
-from aphrodite.common.sequence import SequenceGroup
-
-
-class Policy:
-
-    def get_priority(
-        self,
-        now: float,
-        seq_group: SequenceGroup,
-    ) -> float:
-        raise NotImplementedError
-
-    def sort_by_priority(
-        self,
-        now: float,
-        seq_groups: Deque[SequenceGroup],
-    ) -> Deque[SequenceGroup]:
-        return deque(
-            sorted(
-                seq_groups,
-                key=lambda seq_group: self.get_priority(now, seq_group),
-                reverse=True,
-            ))
-
-
-class FCFS(Policy):
-
-    def get_priority(
-        self,
-        now: float,
-        seq_group: SequenceGroup,
-    ) -> float:
-        return now - seq_group.metrics.arrival_time
-
-
-class PolicyFactory:
-
-    _POLICY_REGISTRY = {'fcfs': FCFS}
-
-    @classmethod
-    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
-        return cls._POLICY_REGISTRY[policy_name](**kwargs)

+ 50 - 79
aphrodite/processing/scheduler.py

@@ -13,7 +13,6 @@ from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
                                        SequenceGroupMetadata, SequenceStatus)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
-from aphrodite.processing.policy import Policy, PolicyFactory
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 
 # Test-only. If configured, decode is preempted with
@@ -314,7 +313,6 @@ class Scheduler:
         # can and must be released after the current step.
         # This is used to evict the finished requests from the Mamba cache.
         self._finished_requests_ids: List[str] = list()
-
         # Time at previous scheduling step
         self.prev_time = 0.0
         # Did we schedule a prompt at previous step?
@@ -345,6 +343,16 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
 
+    def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
+        # Add sequence groups to the running queue.
+        # Only for testing purposes.
+        self.running.append(seq_group)
+
+    def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
+        # Add sequence groups to the swapped queue.
+        # Only for testing purposes.
+        self.swapped.append(seq_group)
+
     def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
         """Aborts a sequence group with the given ID.
 
@@ -398,32 +406,26 @@ class Scheduler:
 
     def _schedule_running(
         self,
-        running_queue: deque,
         budget: SchedulingBudget,
         curr_loras: Optional[Set[int]],
-        policy: Policy,
         enable_chunking: bool = False,
-    ) -> Tuple[deque, SchedulerRunningOutputs]:
+    ) -> SchedulerRunningOutputs:
         """Schedule sequence groups that are running.
 
         Running queue should include decode and chunked prefill requests.
 
         Args:
-            running_queue: The queue that contains running requests (i.e.,
-                decodes). The given arguments are NOT in-place modified.
             budget: The scheduling budget. The argument is in-place updated
                 when any decodes are preempted.
             curr_loras: Currently batched lora request ids. The argument is
                 in-place updated when any decodes are preempted.
-            policy: The sorting policy to sort running_queue.
             enable_chunking: If True, seq group can be chunked and only a
                 chunked number of tokens are scheduled  if
                 `budget.num_batched_tokens` has not enough capacity to schedule
                 all tokens.
     
         Returns:
-            A tuple of remaining running queue (should be always 0) after
-            scheduling and SchedulerRunningOutputs.
+            SchedulerRunningOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_out: List[Tuple[int, int]] = []
@@ -436,10 +438,9 @@ class Scheduler:
 
         # NOTE: Preemption happens only when there is no available slot
         # to keep all the sequence groups in the RUNNING state.
-        # In this case, the policy is responsible for deciding which sequence
-        # groups to preempt.
-        now = time.time()
-        running_queue = policy.sort_by_priority(now, running_queue)
+
+        running_queue = self.running
+
         while running_queue:
             seq_group = running_queue[0]
             num_running_tokens = self._get_num_new_tokens(
@@ -455,6 +456,7 @@ class Scheduler:
                 num_running_seqs = seq_group.get_max_num_running_seqs()
                 budget.subtract_num_seqs(seq_group.request_id,
                                          num_running_seqs)
+
                 if (curr_loras is not None and seq_group.lora_int_id > 0
                         and seq_group.lora_int_id in curr_loras):
                     curr_loras.remove(seq_group.lora_int_id)
@@ -502,7 +504,7 @@ class Scheduler:
                 if curr_loras is not None and seq_group.lora_int_id > 0:
                     curr_loras.add(seq_group.lora_int_id)
 
-        return running_queue, SchedulerRunningOutputs(
+        return SchedulerRunningOutputs(
             decode_seq_groups=decode_seq_groups,
             prefill_seq_groups=prefill_seq_groups,
             preempted=preempted,
@@ -514,12 +516,10 @@ class Scheduler:
 
     def _schedule_swapped(
         self,
-        swapped_queue: deque,
         budget: SchedulingBudget,
         curr_loras: Optional[Set[int]],
-        policy: Policy,
         enable_chunking: bool = False,
-    ) -> Tuple[deque, SchedulerSwappedInOutputs]:
+    ) -> SchedulerSwappedInOutputs:
         """Schedule sequence groups that are swapped out.
 
         It schedules swapped requests as long as it fits `budget` and
@@ -527,20 +527,16 @@ class Scheduler:
         `budget` and `curr_loras` are updated based on scheduled seq_groups.
 
         Args:
-            swapped_queue: The queue that contains swapped out requests.
-                The given arguments are NOT in-place modified.
             budget: The scheduling budget. The argument is in-place updated
                 when any requests are swapped in.
             curr_loras: Currently batched lora request ids. The argument is
                 in-place updated when any requests are swapped in.
-            policy: The sorting policy to sort swapped_queue.
             enable_chunking: If True, seq group can be chunked and only a
                 chunked number of tokens are scheduled  if
                 `budget.num_batched_tokens` has not enough capacity to schedule
                 all tokens.
 
         Returns:
-            A tuple of remaining swapped_queue after scheduling and
             SchedulerSwappedInOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
@@ -548,10 +544,10 @@ class Scheduler:
         blocks_to_copy: List[Tuple[int, int]] = []
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
-        now = time.time()
-        swapped_queue = policy.sort_by_priority(now, swapped_queue)
         infeasible_seq_groups: List[SequenceGroup] = []
 
+        swapped_queue = self.swapped
+
         leftover_swapped: Deque[SequenceGroup] = deque()
         while swapped_queue:
             seq_group = swapped_queue[0]
@@ -615,7 +611,7 @@ class Scheduler:
 
         swapped_queue.extendleft(leftover_swapped)
 
-        return swapped_queue, SchedulerSwappedInOutputs(
+        return SchedulerSwappedInOutputs(
             decode_seq_groups=decode_seq_groups,
             prefill_seq_groups=prefill_seq_groups,
             blocks_to_swap_in=blocks_to_swap_in,
@@ -642,11 +638,10 @@ class Scheduler:
 
     def _schedule_prefills(
         self,
-        waiting_queue: deque,
         budget: SchedulingBudget,
         curr_loras: Optional[Set[int]],
         enable_chunking: bool = False,
-    ) -> Tuple[deque, SchedulerPrefillOutputs]:
+    ) -> SchedulerPrefillOutputs:
         """Schedule sequence groups that are in prefill stage.
 
         Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
@@ -658,8 +653,6 @@ class Scheduler:
         `budget` and `curr_loras` are updated based on scheduled seq_groups.
 
         Args:
-            waiting_queue: The queue that contains prefill requests.
-                The given arguments are NOT in-place modified.
             budget: The scheduling budget. The argument is in-place updated
                 when any requests are scheduled.
             curr_loras: Currently batched lora request ids. The argument is
@@ -670,14 +663,12 @@ class Scheduler:
                 all tokens.
 
         Returns:
-            A tuple of remaining waiting_queue after scheduling and
             SchedulerSwappedInOutputs.
         """
         ignored_seq_groups: List[SequenceGroup] = []
         seq_groups: List[SequenceGroup] = []
-        # We don't sort waiting queue because we assume it is sorted.
-        # Copy the queue so that the input queue is not modified.
-        waiting_queue = deque([s for s in waiting_queue])
+
+        waiting_queue = self.waiting
 
         leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         while self._passed_delay(time.time()) and waiting_queue:
@@ -697,8 +688,8 @@ class Scheduler:
             prompt_limit = self._get_prompt_limit(seq_group)
             if num_new_tokens > prompt_limit:
                 logger.warning(
-                    f"Input prompt ({num_new_tokens}) tokens) is too long"
-                    f" and exceeds limit of {prompt_limit}")
+                    "Input prompt (%d tokens) is too long"
+                    " and exceeds limit of %d", num_new_tokens, prompt_limit)
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
@@ -711,8 +702,9 @@ class Scheduler:
                 break
             elif can_allocate == AllocStatus.NEVER:
                 logger.warning(
-                    f"Input prompt ({num_new_tokens} tokens) is too long"
-                    f" and exceeds the capacity of block_manager")
+                    "Input prompt (%d tokens) is too long"
+                    " and exceeds the capacity of block_manager",
+                    num_new_tokens)
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
@@ -755,7 +747,7 @@ class Scheduler:
         if len(seq_groups) > 0:
             self.prev_prompt = True
 
-        return waiting_queue, SchedulerPrefillOutputs(
+        return SchedulerPrefillOutputs(
             seq_groups=seq_groups,
             ignored_seq_groups=ignored_seq_groups,
             num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
@@ -782,54 +774,43 @@ class Scheduler:
             seq_group.lora_int_id for seq_group in self.running
             if seq_group.lora_int_id > 0) if self.lora_enabled else None
 
-        remaining_waiting, prefills = (self.waiting,
-                                       SchedulerPrefillOutputs.create_empty())
-        remaining_running, running_scheduled = (
-            self.running, SchedulerRunningOutputs.create_empty())
-        remaining_swapped, swapped_in = (
-            self.swapped, SchedulerSwappedInOutputs.create_empty())
+        prefills = SchedulerPrefillOutputs.create_empty()
+        running_scheduled = SchedulerRunningOutputs.create_empty()
+        swapped_in = SchedulerSwappedInOutputs.create_empty()
 
         # If any requests are swapped, prioritized swapped requests.
         if not self.swapped:
-            remaining_waiting, prefills = self._schedule_prefills(
-                self.waiting, budget, curr_loras, enable_chunking=False)
+            prefills = self._schedule_prefills(budget,
+                                               curr_loras,
+                                               enable_chunking=False)
 
-        fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
         # Don't schedule decodes if prefills are scheduled.
         # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
         # only contains decode requests, not chunked prefills.
         if len(prefills.seq_groups) == 0:
-            remaining_running, running_scheduled = self._schedule_running(
-                self.running,
-                budget,
-                curr_loras,
-                fcfs_policy,
-                enable_chunking=False)
+            running_scheduled = self._schedule_running(budget,
+                                                       curr_loras,
+                                                       enable_chunking=False)
 
             # If any sequence group is preempted, do not swap in any sequence
             # group. because it means there's no slot for new running requests.
             if len(running_scheduled.preempted) + len(
                     running_scheduled.swapped_out) == 0:
-                remaining_swapped, swapped_in = self._schedule_swapped(
-                    self.swapped, budget, curr_loras, fcfs_policy)
+                swapped_in = self._schedule_swapped(budget, curr_loras)
 
         assert (budget.num_batched_tokens <=
                 self.scheduler_config.max_num_batched_tokens)
         assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
 
         # Update waiting requests.
-        self.waiting = remaining_waiting
         self.waiting.extendleft(running_scheduled.preempted)
         # Update new running requests.
-        self.running = remaining_running
         self.running.extend([s.seq_group for s in prefills.seq_groups])
         self.running.extend(
             [s.seq_group for s in running_scheduled.decode_seq_groups])
         self.running.extend(
             [s.seq_group for s in swapped_in.decode_seq_groups])
         # Update swapped requests.
-        # Update swapped requests.
-        self.swapped = remaining_swapped
         self.swapped.extend(running_scheduled.swapped_out)
         preempted = (len(running_scheduled.preempted) +
                      len(running_scheduled.swapped_out))
@@ -875,42 +856,32 @@ class Scheduler:
         )
         curr_loras: Set[int] = set()
 
-        remaining_waiting, prefills = (self.waiting,
-                                       SchedulerPrefillOutputs.create_empty())
-        remaining_running, running_scheduled = (
-            self.running, SchedulerRunningOutputs.create_empty())
-        remaining_swapped, swapped_in = (
-            self.swapped, SchedulerSwappedInOutputs.create_empty())
+        prefills = SchedulerPrefillOutputs.create_empty()
+        swapped_in = SchedulerSwappedInOutputs.create_empty()
 
         # Decoding should be always scheduled first by fcfs.
-        fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
-        remaining_running, running_scheduled = self._schedule_running(
-            self.running,
-            budget,
-            curr_loras,
-            fcfs_policy,
-            enable_chunking=True)
+        running_scheduled = self._schedule_running(budget,
+                                                   curr_loras,
+                                                   enable_chunking=True)
 
         # Schedule swapped out requests.
         # If preemption happens, it means we don't have space for swap-in.
         if len(running_scheduled.preempted) + len(
                 running_scheduled.swapped_out) == 0:
-            remaining_swapped, swapped_in = self._schedule_swapped(
-                self.swapped, budget, curr_loras, fcfs_policy)
+            swapped_in = self._schedule_swapped(budget, curr_loras)
 
         # Schedule new prefills.
-        remaining_waiting, prefills = self._schedule_prefills(
-            self.waiting, budget, curr_loras, enable_chunking=True)
+        prefills = self._schedule_prefills(budget,
+                                           curr_loras,
+                                           enable_chunking=True)
 
         assert (budget.num_batched_tokens <=
                 self.scheduler_config.max_num_batched_tokens)
         assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
 
         # Update waiting requests.
-        self.waiting = remaining_waiting
         self.waiting.extendleft(running_scheduled.preempted)
         # Update new running requests.
-        self.running = remaining_running
         self.running.extend([s.seq_group for s in prefills.seq_groups])
         self.running.extend(
             [s.seq_group for s in running_scheduled.decode_seq_groups])
@@ -921,7 +892,6 @@ class Scheduler:
         self.running.extend(
             [s.seq_group for s in swapped_in.prefill_seq_groups])
         # Update swapped requests.
-        self.swapped = remaining_swapped
         self.swapped.extend(running_scheduled.swapped_out)
         return SchedulerOutputs(
             scheduled_seq_groups=(prefills.seq_groups +
@@ -1079,6 +1049,7 @@ class Scheduler:
         blocks_to_copy: List[Tuple[int, int]],
     ) -> None:
         """Appends new slots to the sequences in the given sequence group.
+
         Args:
             seq_group (SequenceGroup): The sequence group containing the
                 sequences to append slots to.