Преглед на файлове

Revert "feat: add support for chunked prefill + prefix caching (#871)"

This reverts commit abfd4465ca5f62ffaf7d37aee247534f80985998.
AlpinDale преди 3 месеца
родител
ревизия
0bf916eabd

+ 6 - 13
aphrodite/processing/block_manager_v1.py

@@ -680,20 +680,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             for block in block_table:
                 block.last_accessed = access_time
 
-    def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
-
-        # When chunked prefill is enabled, the computed full blocks
-        # should be calculated based on the number of computed tokens.
-        max_computed_tokens = (seq.data.get_num_computed_tokens() +
-                               token_chunk_size)
-        computed_full_blocks = max_computed_tokens // self.block_size
-
+        max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
-        if computed_full_blocks == 0:
+        if max_full_block == -1:
             return
-        for i in reversed(range(computed_full_blocks)):
+        for i in reversed(range(max_full_block)):
             if block_table[i].computed:
                 break
             block_table[i].computed = True
@@ -723,11 +717,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
-                                token_chunk_size: int):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
-                self.compute_full_blocks_in_seq(seq, token_chunk_size)
+                self.compute_full_blocks_in_seq(seq)
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         if device == Device.GPU:

+ 1 - 2
aphrodite/processing/block_manager_v2.py

@@ -284,8 +284,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
             self._last_access_blocks_tracker.update_last_access(
                 seq.seq_id, now)
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
-                                token_chunk_size: int):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         # The only need for mark block as computed is for prefix caching,
         # while currently we could determine whether one block is computed
         # or not by check whether it has content hash.

+ 1 - 2
aphrodite/processing/placeholder_block_space_manager.py

@@ -80,8 +80,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
                                       seq_group: SequenceGroup) -> List[int]:
         return None  # type: ignore
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
-                                token_chunk_size: int):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:

+ 6 - 24
aphrodite/processing/scheduler.py

@@ -1152,8 +1152,7 @@ class Scheduler:
         # will crash the Aphrodite instance / will not retry.
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
             self.block_manager.mark_blocks_as_computed(
-                scheduled_seq_group.seq_group,
-                scheduled_seq_group.token_chunk_size)
+                scheduled_seq_group.seq_group)
 
         return seq_group_metadata_list, scheduler_outputs
 
@@ -1345,27 +1344,10 @@ class Scheduler:
         for seq in seqs:
             num_new_tokens += seq.get_num_new_tokens()
         assert num_new_tokens > 0
-        # Chunk if a running request cannot fit in the given budget.
-        # If number of seq > 1, it means it is doing beam search
-        # in a decode phase. Do not chunk.
+        # Chunk if a running request cannot fit in.
+        # If number of seq > 1, it means it is doing beam search in a
+        # decode phase. Do not chunk in that case.
         if enable_chunking and len(seqs) == 1:
-            remaining_token_budget = budget.remaining_token_budget()
-            if self.cache_config.enable_prefix_caching:
-                # When prefix caching is enabled, we always allocate
-                # the number of new tokens that is dividable by the block size
-                # to avoid partial block matching.
-                block_size = self.cache_config.block_size
-                reminder = budget.token_budget % block_size
-                if reminder != 0:
-                    raise ValueError("When enabling chunked prefill and "
-                                     "prefix caching, max_num_batched_tokens "
-                                     "(chunk size) must be dividable by "
-                                     "block size, but got chunk_size "
-                                     f"({budget.token_budget}) % block_size "
-                                     f"({block_size}) = {reminder}")
-                if remaining_token_budget < num_new_tokens:
-                    num_new_tokens = (remaining_token_budget //
-                                      block_size) * block_size
-            else:
-                num_new_tokens = min(num_new_tokens, remaining_token_budget)
+            num_new_tokens = min(num_new_tokens,
+                                 budget.remaining_token_budget())
         return num_new_tokens

+ 12 - 37
aphrodite/task_handler/model_runner.py

@@ -501,48 +501,23 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                             and self.sliding_window is None
                             and inter_data.is_prompt)
         inter_data.prefix_cache_hit = prefix_cache_hit
-
-        if not prefix_cache_hit:
-            return
-
-        assert computed_block_nums is not None
-        # The cache hit prompt tokens in this sequence. Note that
-        # this may be larger than the sequence length if chunked
-        # prefill is enabled.
-        prefix_cache_len = len(computed_block_nums) * self.block_size
-        # The number of so far computed prompt tokens in this sequence.
-        context_len = inter_data.context_lens[seq_idx]
-        # The total number of prompt tokens in this sequence.
-        # When chunked prefill is enabled, this is the token number of
-        # computed chunks + current chunk.
-        seq_len = inter_data.seq_lens[seq_idx]
-        if prefix_cache_len <= context_len:
-            # We already passed the cache hit region,
-            # so do normal computation.
-            pass
-        elif context_len < prefix_cache_len < seq_len:
-            # Partial hit. Compute the missing part.
-            uncomputed_start = prefix_cache_len - context_len
+        if self.chunked_prefill_enabled and prefix_cache_hit:
+            raise RuntimeError(
+                "chunked prefill cannot be used with prefix caching now.")
+
+        # If prefix cache is hit, advance context length to bypass
+        # hit blocks. Accordingly, input tokens, position and query length
+        # have to be updated.
+        if prefix_cache_hit:
+            assert computed_block_nums is not None
+            context_len = len(computed_block_nums) * self.block_size
             inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][uncomputed_start:]
+                seq_idx][context_len:]
             inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][uncomputed_start:]
-            context_len = prefix_cache_len
-
+                seq_idx][context_len:]
             inter_data.context_lens[seq_idx] = context_len
             inter_data.query_lens[
                 seq_idx] = inter_data.seq_lens[seq_idx] - context_len
-        elif seq_len <= prefix_cache_len:
-            # Full hit. Only compute the last token to avoid
-            # erroneous behavior. FIXME: Ideally we should directly
-            # mark all tokens as computed in the scheduler and do not
-            # schedule this sequence, so this case should not happen.
-            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][-1:]
-            inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][-1:]
-            inter_data.query_lens[seq_idx] = 1
-            inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
 
     def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                     seq_idx: int,

+ 0 - 66
tests/basic_correctness/test_chunked_prefill.py

@@ -6,7 +6,6 @@ prefill requests are chunked.
 
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
-from contextlib import nullcontext
 
 import pytest
 
@@ -152,68 +151,3 @@ def test_models_with_fp8_kv_cache(
         name_0="no_chunked_prefill",
         name_1="chunked_prefill",
     )
-
-
-@pytest.mark.parametrize("max_tokens", [16])
-@pytest.mark.parametrize("enforce_eager", [False])
-@pytest.mark.parametrize("chunk_size", [30, 32])
-@pytest.mark.parametrize("use_v2_block_manager", [False, True])
-# NOTE: Increasing this in this suite will fail CI because we currently cannot
-# reset distributed env properly. Use a value > 1 just when you test.
-@pytest.mark.parametrize("tensor_parallel_size", [1])
-def test_with_prefix_caching(
-    aphrodite_runner,
-    max_tokens: int,
-    enforce_eager: bool,
-    chunk_size: int,
-    use_v2_block_manager: bool,
-    tensor_parallel_size: int,
-) -> None:
-    """
-    Checks exact match decode with and without prefix caching
-    with chunked prefill enabled.
-    """
-    model = "meta-llama/Llama-2-7b-chat-hf"
-    # The common prompt has 142 tokens with Llama-2 tokenizer.
-    common_prompt = "You are a helpful AI assistant " * 20
-    unique_prompts = [
-        "Question",  # Warmup
-        "Question",  # Fully cached
-        "Another question",  # Partial cached
-    ]
-    full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
-
-    max_num_batched_tokens = max_num_seqs = chunk_size
-    outputs = {}  # type: ignore
-    check_result = True
-    for enable in (True, False):
-        with aphrodite_runner(
-                model,
-                dtype="half",
-                max_num_batched_tokens=max_num_batched_tokens,
-                enable_chunked_prefill=True,
-                enable_prefix_caching=enable,
-                tensor_parallel_size=tensor_parallel_size,
-                use_v2_block_manager=use_v2_block_manager,
-                enforce_eager=enforce_eager,
-                max_num_seqs=max_num_seqs,
-        ) as aphrodite_model:
-            # It should fail when prefix caching is enable and chunk
-            # size is not a multiple of block size (16).
-            should_fail = chunk_size % 16 != 0 and enable
-            check_result &= not should_fail
-            outputs[enable] = []
-            # Send the request one-by-one to ensure the cache is populated.
-            with pytest.raises(ValueError) if should_fail else nullcontext():
-                for prompt in full_prompts:
-                    outputs[enable] += aphrodite_model.generate_greedy(
-                        [prompt], max_tokens)
-
-    # Check results only if we did not expect a failure.
-    if check_result:
-        check_outputs_equal(
-            outputs_0_lst=outputs[False],
-            outputs_1_lst=outputs[True],
-            name_0="w/o prefix caching",
-            name_1="with prefix caching",
-        )

+ 0 - 40
tests/core/test_block_manager.py

@@ -596,43 +596,3 @@ def test_sliding_window_multi_seq():
 
     # assert all blocks are free now
     assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
-
-
-def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
-    """When prefix cache and chunked prefill are enabled, the block manager
-    should only mark a chunk of blocks as computed instead of all blocks.
-    """
-
-    block_size = 4
-    num_cpu_blocks = 0
-    num_gpu_blocks = 16
-    block_manager = BlockSpaceManagerV1(block_size,
-                                        num_gpu_blocks,
-                                        num_cpu_blocks,
-                                        watermark=0,
-                                        enable_caching=True)
-
-    # Set prompt size to have num_gpu_blocks - 1 full blocks.
-    prompt_length = block_size * num_gpu_blocks - 1
-
-    # Allocate (reserve) all blocks.
-    _, seq_group = create_dummy_prompt("0",
-                                       prompt_length,
-                                       block_size=block_size)
-    block_manager.allocate(seq_group)
-    assert seq_group.seqs[0].n_blocks == num_gpu_blocks
-
-    # 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed.
-    token_chunk_size = int(block_size * 2.5)
-    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
-    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
-    assert len(computed_blocks) == 2
-
-    # Actual computed tokens.
-    seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
-
-    # 2nd chunk: Complete 3rd block and additional 4 blocks.
-    token_chunk_size = int(block_size * 4.5)
-    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
-    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
-    assert len(computed_blocks) == 7

+ 0 - 39
tests/core/test_chunked_prefill_scheduler.py

@@ -581,42 +581,3 @@ def test_chunked_prefill_max_seqs():
     assert len(get_sequence_groups(out)) == max_seqs
     assert not running[0].is_prefill()
     assert not running[1].is_prefill()
-
-
-def test_perfix_caching():
-    """Verify allocating full blocks when prefix caching is enabled."""
-    block_size = 4
-    max_seqs = 10
-    max_model_len = 80
-    max_num_batched_tokens = 64
-    scheduler_config = SchedulerConfig(max_num_batched_tokens,
-                                       max_seqs,
-                                       max_model_len,
-                                       enable_chunked_prefill=True)
-    cache_config = CacheConfig(block_size,
-                               1.0,
-                               1,
-                               "auto",
-                               enable_prefix_caching=True)
-    cache_config.num_cpu_blocks = 0
-    cache_config.num_gpu_blocks = 32
-    scheduler = Scheduler(scheduler_config, cache_config, None)
-    running: List[SequenceGroup] = []
-
-    # Add seq groups to scheduler.
-    for i in range(2):
-        _, seq_group = create_dummy_prompt(str(i),
-                                           block_size=block_size,
-                                           prompt_length=50)
-        scheduler.add_seq_group(seq_group)
-        running.append(seq_group)
-
-    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
-    assert set(get_sequence_groups(out)) == set(running)
-    assert seq_group_meta[0].token_chunk_size == 50
-    # Verify it is chunked. Note that although the budget is 64-50=14,
-    # we only allocate full blocks for prefix caching, so only 4*(14//4)=12
-    # tokens are allocated.
-    assert seq_group_meta[1].token_chunk_size == 12
-    assert out.num_prefill_groups == 2
-    assert out.num_batched_tokens == 62