
spec decode: move ops.advance_step to flash attention backend (#1005)

AlpinDale committed 2 months ago · commit 5c3b94de45

+ 15 - 5
aphrodite/attention/backends/flash_attn.py

@@ -18,7 +18,8 @@ from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import (ModelInputForGPUBuilder,
+                                               ModelInputForGPUWithSamplingMetadata)
 
 from aphrodite_flash_attn import (
     flash_attn_varlen_func as _flash_attn_varlen_func)
@@ -305,13 +306,12 @@ class FlashAttentionMetadata(AttentionMetadata):
         )
         return self._cached_decode_metadata
 
-    def advance_step(self, num_seqs: int, num_queries: int):
+    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int, num_seqs: int, num_queries: int):
         """
         Update metadata in-place to advance one decode step.
         """
-        # GPU in-place update is currently called separately through
-        # custom_ops.advance_step(). See draft_model_runner.
-        # TODO: Move this logic to the backend.
 
         # When using cudagraph, the num_seqs is padded to the next captured
         # batch size, but num_queries tracks the actual number of requests in
@@ -350,6 +350,16 @@ class FlashAttentionMetadata(AttentionMetadata):
             self.seq_lens[i] += 1
         self.max_decode_seq_len = max(self.seq_lens)
 
+        ops.advance_step(num_seqs=num_seqs,
+                         num_queries=num_queries,
+                         block_size=block_size,
+                         input_tokens=model_input.input_tokens,
+                         sampled_token_ids=sampled_token_ids,
+                         input_positions=model_input.input_positions,
+                         seq_lens=self.seq_lens_tensor,
+                         slot_mapping=self.slot_mapping,
+                         block_tables=self.block_tables)
+
 
 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
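
The metadata method above now performs the GPU-side update itself. As a rough guide to its semantics, here is a minimal PyTorch sketch of what the fused ops.advance_step kernel does for the first num_queries rows (the function name, argument shapes, and pure-torch formulation are assumptions for illustration; the real kernel is a single fused CUDA launch and leaves any cudagraph padding rows past num_queries untouched):

    import torch

    def advance_step_reference(num_queries: int,
                               block_size: int,
                               input_tokens: torch.Tensor,       # [num_seqs]
                               sampled_token_ids: torch.Tensor,  # [num_queries, 1]
                               input_positions: torch.Tensor,    # [num_seqs]
                               seq_lens: torch.Tensor,           # [num_seqs]
                               slot_mapping: torch.Tensor,       # [num_seqs]
                               block_tables: torch.Tensor) -> None:  # [num_seqs, max_blocks]
        q = torch.arange(num_queries, device=input_tokens.device)
        # The token sampled at the last step becomes the next input token.
        input_tokens[q] = sampled_token_ids.reshape(-1)[q]
        # The new token's position equals the old sequence length, and each
        # sequence then grows by one.
        next_pos = seq_lens[q]
        input_positions[q] = next_pos
        seq_lens[q] = next_pos + 1
        # Map the new position to its physical KV-cache slot via the block table.
        slot_mapping[q] = (block_tables[q, next_pos // block_size] * block_size
                           + next_pos % block_size)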

+ 2 - 14
aphrodite/spec_decode/draft_model_runner.py

@@ -3,8 +3,6 @@ from typing import List, Optional
 import torch
 from loguru import logger
 
-from aphrodite import _custom_ops as ops
-
 try:
     from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
 except ModuleNotFoundError:
@@ -114,18 +112,8 @@ class TP1DraftModelRunner(ModelRunner):
         # Update attn_metadata
         attn_metadata = model_input.attn_metadata
         assert isinstance(attn_metadata, FlashAttentionMetadata)
-        attn_metadata.advance_step(num_seqs, num_queries)
-
-        # Update GPU tensors
-        ops.advance_step(num_seqs=num_seqs,
-                         num_queries=num_queries,
-                         block_size=self.block_size,
-                         input_tokens=model_input.input_tokens,
-                         sampled_token_ids=sampled_token_ids,
-                         input_positions=model_input.input_positions,
-                         seq_lens=attn_metadata.seq_lens_tensor,
-                         slot_mapping=attn_metadata.slot_mapping,
-                         block_tables=attn_metadata.block_tables)
+        attn_metadata.advance_step(model_input, sampled_token_ids,
+                                   self.block_size, num_seqs, num_queries)
 
         # Update sampling_metadata
         sampling_metadata = model_input.sampling_metadata
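
To make the slot_mapping computation concrete, a worked example with hypothetical numbers (the block size, position, and physical block id are all made up):

    block_size = 16
    next_pos = 37                             # new token's position in the sequence
    logical_block = next_pos // block_size    # 37 // 16 == 2
    physical_block = 42                       # assume block_tables[i][2] == 42
    slot = physical_block * block_size + next_pos % block_size  # 42*16 + 5 == 677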

+ 4 - 14
aphrodite/worker/multi_step_model_runner.py

@@ -14,7 +14,6 @@ except ModuleNotFoundError:
 
 import torch
 
-from aphrodite import _custom_ops as ops
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                        IntermediateTensors, Logprob,
                                        SequenceGroupMetadata, SequenceOutput)
@@ -490,19 +489,10 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
         assert num_seqs >= num_queries
         attn_metadata = frozen_model_input.attn_metadata
         assert isinstance(attn_metadata, FlashAttentionMetadata)
-        attn_metadata.advance_step(num_seqs, num_queries)
-        # Update GPU tensors
-        ops.advance_step(
-            num_seqs=num_seqs,
-            num_queries=num_queries,
-            block_size=self.block_size,
-            input_tokens=frozen_model_input.input_tokens,
-            sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
-            input_positions=frozen_model_input.input_positions,
-            seq_lens=attn_metadata.seq_lens_tensor,
-            slot_mapping=attn_metadata.slot_mapping,
-            block_tables=attn_metadata.block_tables,
-        )
+        attn_metadata.advance_step(
+            frozen_model_input,
+            model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
+            num_seqs, num_queries)
         if frozen_model_input.seq_lens is not None:
             for i in range(num_queries):
                 frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
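
For context, a hypothetical multi-step decode loop built on this API (the driver variables, the model and sampler calls, and num_steps are illustrative assumptions, not the runner's actual code): each iteration samples once and advances the metadata in place, so no full prepare-inputs pass is needed between steps.

    for _ in range(num_steps):
        hidden_states = model(model_input.input_tokens,
                              model_input.input_positions,
                              kv_caches, attn_metadata)
        sampled = sampler(hidden_states)   # assumed [num_queries, 1] token ids
        # One call now covers both the Python-side metadata and the GPU tensors.
        attn_metadata.advance_step(model_input, sampled,
                                   block_size, num_seqs, num_queries)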