feat: add single user mode (#927)

* feat: add single user mode

* guard against prefix caching

* guard against `n > 1`

* guard against beam search

AlpinDale, 2 months ago · commit 8d9f1fd4e6

+ 22 - 2
aphrodite/common/config.py

@@ -1094,12 +1094,15 @@ class SchedulerConfig:
             workers instead of the entire data. It should be enabled only
             when SPMD worker architecture is enabled. I.e.,
             APHRODITE_USE_RAY_SPMD_WORKER=1
+        single_user_mode: If True, only allocate enough KV cache blocks for
+            a single sequence of up to the maximum model length in tokens.
     """
 
     def __init__(self,
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
+                 cache_config: Optional["CacheConfig"] = None,
                  is_attention_free: bool = False,
                  use_v2_block_manager: bool = False,
                  num_lookahead_slots: int = 0,
@@ -1108,7 +1111,8 @@ class SchedulerConfig:
                  embedding_mode: Optional[bool] = False,
                  preemption_mode: Optional[str] = None,
                  num_scheduler_steps: int = 1,
-                 send_delta_data: bool = False) -> None:
+                 send_delta_data: bool = False,
+                 single_user_mode: bool = False) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
@@ -1130,9 +1134,25 @@ class SchedulerConfig:
             logger.info(
                 "Chunked prefill is enabled with "
                 f"max_num_batched_tokens={self.max_num_batched_tokens}.")
+        if single_user_mode:
+            max_num_seqs = 1
+            if cache_config.enable_prefix_caching:
+                if not envs.APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE:
+                    logger.warning(
+                        "Prefix caching is not supported in single user "
+                        "mode and will be disabled. Set "
+                        "APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE=1 to "
+                        "force-enable it (not recommended).")
+                    cache_config.enable_prefix_caching = False
+                else:
+                    logger.warning(
+                        "Prefix caching is force-enabled in single user "
+                        "mode; this is not recommended and may lead to "
+                        "memory issues.")
 
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self.cache_config = cache_config
         self.is_attention_free = is_attention_free
         self.use_v2_block_manager = use_v2_block_manager
         self.num_lookahead_slots = num_lookahead_slots
@@ -1142,7 +1162,7 @@ class SchedulerConfig:
         self.preemption_mode = preemption_mode
         self.num_scheduler_steps = num_scheduler_steps
         self.send_delta_data = send_delta_data
-
+        self.single_user_mode = single_user_mode
         self._verify_args()
 
     def _verify_args(self) -> None:
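
For reference, a minimal sketch of what the new guard does, using stand-in objects rather than the real SchedulerConfig/CacheConfig classes (the helper and stand-in names below are illustrative, not part of the commit):

import os

class _CacheCfg:
    def __init__(self, enable_prefix_caching: bool):
        self.enable_prefix_caching = enable_prefix_caching

def apply_single_user_guard(cache_config: _CacheCfg) -> int:
    # Mirrors the new SchedulerConfig logic: one sequence at a time, and
    # prefix caching disabled unless explicitly forced via the env var.
    max_num_seqs = 1
    force = bool(int(os.getenv(
        "APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE", "0")))
    if cache_config.enable_prefix_caching and not force:
        cache_config.enable_prefix_caching = False
    return max_num_seqs

cfg = _CacheCfg(enable_prefix_caching=True)
assert apply_single_user_guard(cfg) == 1
assert cfg.enable_prefix_caching is False  # disabled unless forced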

+ 6 - 0
aphrodite/common/envs.py

@@ -55,6 +55,7 @@ if TYPE_CHECKING:
     APHRODITE_ALLOW_ENGINE_USE_RAY: bool = False
     APHRODITE_PLUGINS: Optional[List[str]] = None
     APHRODITE_RPC_GET_DATA_TIMEOUT_MS: int = 5000
+    APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE: bool = False
 
 
 def get_default_cache_root():
@@ -381,6 +382,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "APHRODITE_PLUGINS":
     lambda: None if "APHRODITE_PLUGINS" not in os.environ else os.environ[
         "APHRODITE_PLUGINS"].split(","),
+
+    # If set, keep prefix caching enabled in single user mode
+    "APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE":
+    lambda: bool(int(os.getenv("APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE",
+                               "0"))),
 }
 
 # end-env-vars-definition
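
A quick sanity check of the new variable, assuming the module's usual lazy attribute lookup (any non-zero integer value enables it):

import os
import aphrodite.common.envs as envs

os.environ["APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE"] = "1"
# The envs module resolves each variable through its lambda at access time,
# so this reflects the current environment.
print(envs.APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE)  # True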

+ 10 - 0
aphrodite/common/sampling_params.py

@@ -10,6 +10,7 @@ from loguru import logger
 from typing_extensions import Annotated
 
 import aphrodite.common.envs as envs
+from aphrodite.common.config import SchedulerConfig
 
 _SAMPLING_EPS = 1e-5
 _MAX_TEMP = 1e-2
@@ -547,6 +548,15 @@ class SamplingParams(
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using greedy sampling.")
 
+    def _verify_with_scheduler_config(
+            self, scheduler_config: "SchedulerConfig") -> None:
+        if scheduler_config.single_user_mode:
+            if self.n > 1:
+                raise ValueError("n must be 1 in single user mode.")
+            if self.use_beam_search:
+                raise ValueError(
+                    "beam search is not supported in single user mode.")
+
     def update_from_generation_config(
             self,
             generation_config: Dict[str, Any],
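
A small sketch of the new check in isolation; SimpleNamespace stands in for SchedulerConfig since only its single_user_mode attribute is read here, and the top-level SamplingParams import is assumed to be the usual re-export:

from types import SimpleNamespace
from aphrodite import SamplingParams

scheduler_config = SimpleNamespace(single_user_mode=True)

params = SamplingParams(n=2)
try:
    params._verify_with_scheduler_config(scheduler_config)
except ValueError as err:
    print(err)  # n must be 1 in single user mode.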

+ 1 - 0
aphrodite/engine/aphrodite_engine.py

@@ -1049,6 +1049,7 @@ class AphroditeEngine:
 
         sampling_params.update_from_generation_config(
             self.generation_config_fields, seq.eos_token_id)
+        sampling_params._verify_with_scheduler_config(self.scheduler_config)
 
         # Create the sequence group.
         seq_group = SequenceGroup(

+ 9 - 0
aphrodite/engine/args_tools.py

@@ -119,6 +119,7 @@ class EngineArgs:
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     num_scheduler_steps: int = 1
+    single_user_mode: bool = False
     # Speculative Decoding Options
     num_lookahead_slots: int = 0
     speculative_model: Optional[str] = None
@@ -641,6 +642,12 @@ class EngineArgs:
             help="Category: API Options\n"
             "maximum number of sequences per iteration",
         )
+        parser.add_argument('--single-user-mode',
+                            action='store_true',
+                            help='Category: API Options\n'
+                            'If set, only allocate enough KV cache blocks '
+                            'for a single sequence of up to the maximum '
+                            'model length in tokens.')
         parser.add_argument('--num-scheduler-steps',
                             type=int,
                             default=1,
@@ -1047,6 +1054,7 @@ class EngineArgs:
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
+            cache_config=cache_config,
             is_attention_free=model_config.is_attention_free(),
             use_v2_block_manager=self.use_v2_block_manager,
             num_lookahead_slots=num_lookahead_slots,
@@ -1057,6 +1065,7 @@ class EngineArgs:
             num_scheduler_steps=self.num_scheduler_steps,
             send_delta_data=(APHRODITE_USE_RAY_SPMD_WORKER and
                              parallel_config.use_ray),
+            single_user_mode=self.single_user_mode,
         )
 
         if not HAS_TRITON and self.enable_lora:
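
End to end, the new flag is exposed as --single-user-mode on the server command line and as an EngineArgs field for offline use. A rough offline sketch, assuming the LLM wrapper forwards extra engine keyword arguments the same way it does for the other EngineArgs fields (the model name is only a placeholder):

from aphrodite import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", single_user_mode=True)

# n must stay 1 and beam search is rejected in single user mode.
params = SamplingParams(temperature=0.7, max_tokens=64, n=1)
print(llm.generate(["Single user mode keeps the KV cache small:"], params))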

+ 54 - 24
aphrodite/task_handler/model_runner.py

@@ -1059,6 +1059,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
+
+        single_seq_mode = self.scheduler_config.single_user_mode
+        if single_seq_mode and rank == 0:
+            logger.info("Running in single sequence profiling mode")
+
         # This represents the maximum number of different requests
         # that will have unique loras, and therefore the max amount of memory
         # consumption create dummy lora request copies from the lora request
@@ -1078,10 +1083,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                      rank=LORA_WARMUP_RANK)
                     dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
+                if single_seq_mode:
+                    dummy_lora_requests_per_seq = [dummy_lora_requests[0]]
+                else:
+                    dummy_lora_requests_per_seq = [
+                        dummy_lora_requests[idx % len(dummy_lora_requests)]
+                        for idx in range(max_num_seqs)
+                    ]
 
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
@@ -1095,39 +1103,61 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
             self.model_config)
-        if max_mm_tokens > 0:
-            max_num_seqs_orig = max_num_seqs
-            max_num_seqs = min(max_num_seqs,
-                               max_num_batched_tokens // max_mm_tokens)
-            if max_num_seqs < 1:
-                expr = (f"min({max_num_seqs_orig}, "
+        if max_mm_tokens > 0 and not single_seq_mode:
+                max_num_seqs_orig = max_num_seqs
+                max_num_seqs = min(max_num_seqs,
+                                max_num_batched_tokens // max_mm_tokens)
+                if max_num_seqs < 1:
+                    expr = (f"min({max_num_seqs_orig}, "
                         f"{max_num_batched_tokens} // {max_mm_tokens})")
-                logger.warning(
-                    f"Computed max_num_seqs ({expr}) to be less than 1. "
-                    "Setting it to the minimum value of 1.")
-                max_num_seqs = 1
+                    logger.warning(
+                        f"Computed max_num_seqs ({expr}) to be less than 1. "
+                        "Setting it to the minimum value of 1.")
+                    max_num_seqs = 1
         batch_size = 0
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            batch_size += seq_len
-
+        if single_seq_mode:
+            seq_len = max_num_batched_tokens
+            batch_size = seq_len
+            
             seq_data, dummy_multi_modal_data = self.input_registry \
                 .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry)
+                                        seq_len,
+                                        self.mm_registry)
 
             seq = SequenceGroupMetadata(
-                request_id=str(group_id),
+                request_id="0",
                 is_prompt=True,
-                seq_data={group_id: seq_data},
+                seq_data={0: seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
-                lora_request=dummy_lora_requests_per_seq[group_id]
+                lora_request=dummy_lora_requests_per_seq[0]
                 if dummy_lora_requests_per_seq else None,
                 multi_modal_data=dummy_multi_modal_data,
             )
             seqs.append(seq)
+        else:
+            # Original multi-sequence profiling logic
+            for group_id in range(max_num_seqs):
+                seq_len = (max_num_batched_tokens // max_num_seqs +
+                        (group_id < max_num_batched_tokens % max_num_seqs))
+                batch_size += seq_len
+                
+                seq_data, dummy_multi_modal_data = self.input_registry \
+                    .dummy_data_for_profiling(self.model_config,
+                                            seq_len,
+                                            self.mm_registry)
+
+                seq = SequenceGroupMetadata(
+                    request_id=str(group_id),
+                    is_prompt=True,
+                    seq_data={group_id: seq_data},
+                    sampling_params=sampling_params,
+                    block_tables=None,
+                    lora_request=dummy_lora_requests_per_seq[group_id]
+                    if dummy_lora_requests_per_seq else None,
+                    multi_modal_data=dummy_multi_modal_data,
+                )
+                seqs.append(seq)
 
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
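
The effect on memory profiling is easiest to see with concrete numbers (illustrative only, not taken from the diff):

max_num_batched_tokens, max_num_seqs = 8192, 256

# single user mode: one dummy sequence consumes the whole token budget
single_seq_lens = [max_num_batched_tokens]                      # [8192]

# default mode: the budget is split as evenly as possible across sequences
multi_seq_lens = [max_num_batched_tokens // max_num_seqs +
                  (i < max_num_batched_tokens % max_num_seqs)
                  for i in range(max_num_seqs)]                 # 256 x 32 tokens

assert sum(multi_seq_lens) == sum(single_seq_lens) == max_num_batched_tokens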

+ 26 - 4
aphrodite/task_handler/worker.py

@@ -216,11 +216,33 @@ class Worker(LocalOrDistributedWorkerBase):
             num_gpu_blocks = 0
             num_cpu_blocks = 0
         else:
-            num_gpu_blocks = int(
-                (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-                 peak_memory) // cache_block_size)
+            # if single_user_mode is set to True, we only allocate enough blocks
+            # for one sequence
+            if self.scheduler_config.single_user_mode:
+                num_gpu_blocks = (self.model_config.max_model_len +
+                                  self.cache_config.block_size - 1
+                                  ) // self.cache_config.block_size
+                max_possible_blocks = int(
+                    (total_gpu_memory *
+                     self.cache_config.gpu_memory_utilization -
+                     peak_memory) // cache_block_size)
+                num_gpu_blocks = min(num_gpu_blocks, max_possible_blocks)
+                if tp_rank == 0:
+                    logger.info(
+                        f"Single sequence mode: Allocating {num_gpu_blocks} "
+                        "blocks "
+                        f"({num_gpu_blocks * self.cache_config.block_size} "
+                        "tokens)")
+            else:
+                # Original logic for multi-sequence mode
+                num_gpu_blocks = int(
+                    (total_gpu_memory *
+                     self.cache_config.gpu_memory_utilization -
+                     peak_memory) // cache_block_size)
+
             num_cpu_blocks = int(self.cache_config.swap_space_bytes //
-                                 cache_block_size)
+                                cache_block_size)
+
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
         if self.model_runner.lora_manager:
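
In single user mode the KV cache budget is just enough blocks for one sequence of max_model_len tokens, still capped by what actually fits after profiling. A worked example with illustrative numbers (none of them come from the diff):

max_model_len, block_size = 32768, 16
blocks_for_one_seq = (max_model_len + block_size - 1) // block_size    # 2048

total_gpu_memory = 24 * 1024**3        # e.g. a 24 GiB GPU
gpu_memory_utilization = 0.90
peak_memory = 18 * 1024**3             # peak usage measured during profiling
cache_block_size = 2 * 1024**2         # bytes of KV cache per block

max_possible_blocks = int((total_gpu_memory * gpu_memory_utilization
                           - peak_memory) // cache_block_size)          # 1843
num_gpu_blocks = min(blocks_for_one_seq, max_possible_blocks)           # 1843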