|
@@ -18,7 +18,8 @@ from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
|
|
|
from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
- from aphrodite.worker.model_runner import ModelInputForGPUBuilder
|
|
|
+ from aphrodite.worker.model_runner import (ModelInputForGPUBuilder,
|
|
|
+ ModelInputForGPUWithSamplingMetadata)
|
|
|
|
|
|
from aphrodite_flash_attn import (
|
|
|
flash_attn_varlen_func as _flash_attn_varlen_func)
|
|
@@ -305,13 +306,12 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|
|
)
|
|
|
return self._cached_decode_metadata
|
|
|
|
|
|
- def advance_step(self, num_seqs: int, num_queries: int):
|
|
|
+ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
|
+ sampled_token_ids: Optional[torch.Tensor],
|
|
|
+ block_size: int, num_seqs: int, num_queries: int):
|
|
|
"""
|
|
|
Update metadata in-place to advance one decode step.
|
|
|
"""
|
|
|
- # GPU in-place update is currently called separately through
|
|
|
- # custom_ops.advance_step(). See draft_model_runner.
|
|
|
- # TODO: Move this logic to the backend.
|
|
|
|
|
|
# When using cudagraph, the num_seqs is padded to the next captured
|
|
|
# batch sized, but num_queries tracks the actual number of requests in
|
|
@@ -350,6 +350,16 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|
|
self.seq_lens[i] += 1
|
|
|
self.max_decode_seq_len = max(self.seq_lens)
|
|
|
|
|
|
+ ops.advance_step(num_seqs=num_seqs,
|
|
|
+ num_queries=num_queries,
|
|
|
+ block_size=block_size,
|
|
|
+ input_tokens=model_input.input_tokens,
|
|
|
+ sampled_token_ids=sampled_token_ids,
|
|
|
+ input_positions=model_input.input_positions,
|
|
|
+ seq_lens=self.seq_lens_tensor,
|
|
|
+ slot_mapping=self.slot_mapping,
|
|
|
+ block_tables=self.block_tables)
|
|
|
+
|
|
|
|
|
|
class FlashAttentionMetadataBuilder(
|
|
|
AttentionMetadataBuilder[FlashAttentionMetadata]):
|