
chore: multi-step args and sequence modifications (#713)

AlpinDale, 6 months ago
Parent commit: 577586309d

aphrodite/common/config.py (+13, -1)

@@ -923,7 +923,8 @@ class SchedulerConfig:
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
-                 preemption_mode: Optional[str] = None) -> None:
+                 preemption_mode: Optional[str] = None,
+                 num_scheduler_steps: int = 1) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
@@ -952,6 +953,7 @@ class SchedulerConfig:
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
+        self.num_scheduler_steps = num_scheduler_steps
 
         self._verify_args()
 
@@ -978,6 +980,16 @@ class SchedulerConfig:
                 f"({self.num_lookahead_slots}) must be greater than or "
                 "equal to 0.")
 
+        if self.num_scheduler_steps < 1:
+            raise ValueError(
+                "num_scheduler_steps "
+                f"({self.num_scheduler_steps}) must be greater than or "
+                "equal to 1.")
+
+    @property
+    def is_multi_step(self) -> bool:
+        return self.num_scheduler_steps > 1
+
 
 class DeviceConfig:
 

aphrodite/common/sequence.py (+56, -1)

@@ -7,6 +7,7 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, cast
 
+import numpy
 import torch
 
 from aphrodite.common.pooling_params import PoolingParams
@@ -474,6 +475,19 @@ class Sequence:
                 f"num_blocks={self.n_blocks}, ")
 
 
+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # for multi-step decoding
+    num_steps: int = 1
+    current_step: int = 0
+
+    @property
+    def remaining_steps(self) -> int:
+        return self.num_steps - self.current_step
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
 
@@ -516,6 +530,7 @@ class SequenceGroup:
                                       time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -569,6 +584,10 @@ class SequenceGroup:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
                          if self.prompt_adapter_request else 0
 
+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -735,6 +754,7 @@ class SequenceGroupMetadata:
         token_chunk_size: The number of tokens to be processed (per sequence).
             None if chunking is not required.
         lora_request: LoRA request.
+        state: Internal state tied to this sequence group.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
         multi_modal_data: Multi modal data.
@@ -762,6 +782,7 @@ class SequenceGroupMetadata:
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -777,6 +798,7 @@ class SequenceGroupMetadata:
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -815,6 +837,10 @@ class SequenceGroupMetadata:
         assert self._token_chunk_size is not None
         return self._token_chunk_size
 
+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+
 
 class SequenceOutput:
     """The model output associated with a sequence.
@@ -952,6 +978,7 @@ class SamplerOutput:
 
     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None
 
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
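
SamplerOutput also grows a sampled_token_ids_numpy field next to the on-device sampled_token_ids tensor. The diff does not show where it is populated, so the snippet below is only an assumption about the intended use: keeping a host-side numpy mirror of the sampled ids so multi-step bookkeeping can read them without touching the GPU tensor again.

import numpy
import torch

sampled_token_ids = torch.tensor([[42], [7]])  # on-device in the real engine

# Hypothetical fill of the new field: a CPU/numpy copy of the sampled ids.
sampled_token_ids_numpy: numpy.ndarray = sampled_token_ids.cpu().numpy()
assert sampled_token_ids_numpy.shape == (2, 1)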
@@ -1086,6 +1113,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO: make this handle batches with a variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO: make this handle batches with a variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO: make this handle batches with a variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step
 
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1102,4 +1156,5 @@ class ExecuteModelRequest:
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
             finished_requests_ids=self.finished_requests_ids,
-        )
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
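
ExecuteModelRequest picks up a last_sampled_token_ids tensor for carrying the previous step's tokens into the next one, plus three step-tracking properties (is_first_multi_step, is_last_step, current_step). All three inspect only the first sequence group's state, hence the repeated TODO about variable-step batches. Note that clone() copies the tensor with .clone() instead of aliasing it; a minimal check of that semantics:

import torch

last_sampled_token_ids = torch.tensor([[3], [5]])

# What clone() does with the new field (None passes through unchanged):
copied = (last_sampled_token_ids.clone()
          if last_sampled_token_ids is not None else None)
copied[0, 0] = 99
assert last_sampled_token_ids[0, 0].item() == 3  # original is untouched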

aphrodite/engine/args_tools.py (+25, -3)

@@ -111,6 +111,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'outlines'
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
+    num_scheduler_steps: int = 1
     # Speculative Decoding Options
     num_lookahead_slots: int = 0
     speculative_model: Optional[str] = None
@@ -617,6 +618,11 @@ class EngineArgs:
             help="Category: API Options\n"
             "maximum number of sequences per iteration",
         )
+        parser.add_argument('--num-scheduler-steps',
+                            type=int,
+                            default=1,
+                            help=('Maximum number of forward steps per '
+                                  'scheduler call.'))
         # Speculative Decoding Options
         parser.add_argument("--num-lookahead-slots",
                             type=int,
@@ -970,19 +976,35 @@ class EngineArgs:
             disable_logprobs=self.disable_logprobs_during_spec_decoding,
         )
 
+        if self.num_scheduler_steps > 1:
+            raise NotImplementedError("Multi-step is not yet supported.")
+            if speculative_config is not None:
+                raise ValueError("Speculative decoding is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+            if self.enable_chunked_prefill:
+                raise ValueError("Chunked prefill is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+
+        # make sure num_lookahead_slots is set to the higher value, depending
+        # on whether we are using speculative decoding or multi-step
+        num_lookahead_slots = max(self.num_lookahead_slots,
+                                  self.num_scheduler_steps - 1)
+        num_lookahead_slots = num_lookahead_slots \
+            if speculative_config is None \
+            else speculative_config.num_lookahead_slots
+
         scheduler_config = SchedulerConfig(
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
             is_attention_free=model_config.is_attention_free(),
             use_v2_block_manager=self.use_v2_block_manager,
-            num_lookahead_slots=(self.num_lookahead_slots
-                                 if speculative_config is None else
-                                 speculative_config.num_lookahead_slots),
+            num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
             preemption_mode=self.preemption_mode,
+            num_scheduler_steps=self.num_scheduler_steps,
         )
 
         lora_config = LoRAConfig(
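
On the args side, --num-scheduler-steps (default 1) flows straight into SchedulerConfig, although the unconditional NotImplementedError above currently disables the feature and leaves the speculative-decoding and chunked-prefill checks after it unreachable. The commit also folds multi-step into the lookahead-slot computation: without speculative decoding, each extra scheduler step needs one lookahead slot. A sketch of that computation with plain values in place of the config objects (the helper name is mine):

def resolve_num_lookahead_slots(num_lookahead_slots: int,
                                num_scheduler_steps: int,
                                spec_num_lookahead_slots=None) -> int:
    # Multi-step needs (num_scheduler_steps - 1) lookahead slots; take the
    # max with whatever was requested explicitly.
    resolved = max(num_lookahead_slots, num_scheduler_steps - 1)
    # A speculative config, when present, overrides the value entirely.
    return (resolved if spec_num_lookahead_slots is None
            else spec_num_lookahead_slots)

assert resolve_num_lookahead_slots(0, num_scheduler_steps=8) == 7
assert resolve_num_lookahead_slots(4, num_scheduler_steps=2) == 4
assert resolve_num_lookahead_slots(0, 8, spec_num_lookahead_slots=5) == 5

Once the guard is lifted, passing --num-scheduler-steps 8 on the command line (or EngineArgs(num_scheduler_steps=8) programmatically) would request up to eight forward steps per scheduler call.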

aphrodite/processing/scheduler.py (+5, -0)

@@ -803,6 +803,9 @@ class Scheduler:
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)
+            seq_group.init_multi_step(
+                num_scheduler_steps=self._get_num_lookahead_slots(
+                    is_prefill=True) + 1)
             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
                                        token_chunk_size=num_new_tokens))
@@ -1105,6 +1108,7 @@ class Scheduler:
                 computed_block_nums=common_computed_block_nums,
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
+                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1170,6 +1174,7 @@ class Scheduler:
                 slots.
         """
         num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
 
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
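
Finally, the scheduler wires the counter up: when a sequence group is first allocated (prefill) and again on every decode-time slot append, it calls init_multi_step with num_lookahead_slots + 1, so num_steps mirrors the lookahead budget computed in args_tools above, and it forwards the shared state object into each SequenceGroupMetadata. A sketch of that relationship, with stand-ins for the scheduler pieces:

from dataclasses import dataclass

@dataclass
class SequenceGroupState:      # same dataclass as in sequence.py above
    num_steps: int = 1
    current_step: int = 0

NUM_LOOKAHEAD_SLOTS = 7        # stand-in for _get_num_lookahead_slots(...)

class SeqGroupSketch:          # minimal stand-in for SequenceGroup
    def __init__(self) -> None:
        self.state = SequenceGroupState()

    def init_multi_step(self, num_scheduler_steps: int) -> None:
        self.state.num_steps = num_scheduler_steps
        self.state.current_step = 0

group = SeqGroupSketch()
group.init_multi_step(num_scheduler_steps=NUM_LOOKAHEAD_SLOTS + 1)

# The scheduler passes state=seq_group.state into SequenceGroupMetadata, so
# a worker-side finish_step() increment is visible to the scheduler too:
group.state.current_step += 1  # what finish_step() does once per step
assert group.state.num_steps - group.state.current_step == 7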