Sfoglia il codice sorgente

rocm: enable multi-step scheduling for rocm (#1071)

AlpinDale 2 mesi fa
parent
commit
d9d287a288

+ 51 - 1
aphrodite/attention/backends/rocm_flash_attn.py

@@ -1,6 +1,6 @@
 """Attention layer ROCm GPUs."""
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from loguru import logger
@@ -16,6 +16,10 @@ from aphrodite.attention.backends.utils import (CommonAttentionState,
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
+if TYPE_CHECKING:
+    from aphrodite.worker.model_runner import (
+        ModelInputForGPUWithSamplingMetadata)
+
 _PARTITION_SIZE_ROCM = 512
 _ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
 
@@ -178,6 +182,52 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
         )
         return self._cached_decode_metadata
 
+    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int, num_seqs: int, num_queries: int):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+        # When using cudagraph, the num_seqs is padded to the next captured
+        # batch sized, but num_queries tracks the actual number of requests in
+        # the batch. For --enforce-eager mode, num_seqs == num_queries
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.num_decode_tokens == num_seqs
+        assert self.slot_mapping.shape == (num_seqs, )
+        assert self.seq_lens is not None
+        assert len(self.seq_lens) == num_seqs
+        assert self.seq_lens_tensor is not None
+        assert self.seq_lens_tensor.shape == (num_seqs, )
+        assert self.max_query_len == 1
+        assert self.max_prefill_seq_len == 0
+        assert self.max_decode_seq_len == max(self.seq_lens)
+        assert self.query_start_loc is not None
+        assert self.query_start_loc.shape == (num_queries + 1, )
+        assert self.seq_start_loc is not None
+        assert self.seq_start_loc.shape == (num_seqs + 1, )
+        assert self.context_lens_tensor is not None
+        assert self.context_lens_tensor.shape == (num_queries, )
+        assert self.block_tables is not None
+        assert self.block_tables.shape[0] == num_seqs
+        # Update query lengths. Note that we update only queries and not seqs,
+        # since tensors may be padded due to captured cuda graph batch size
+        for i in range(num_queries):
+            self.seq_lens[i] += 1
+        self.max_decode_seq_len = max(self.seq_lens)
+        ops.advance_step_flashattn(num_seqs=num_seqs,
+                                   num_queries=num_queries,
+                                   block_size=block_size,
+                                   input_tokens=model_input.input_tokens,
+                                   sampled_token_ids=sampled_token_ids,
+                                   input_positions=model_input.input_positions,
+                                   seq_lens=self.seq_lens_tensor,
+                                   slot_mapping=self.slot_mapping,
+                                   block_tables=self.block_tables)
+
 
 class ROCmFlashAttentionMetadataBuilder(
         CommonMetadataBuilder[ROCmFlashAttentionMetadata]):

+ 1 - 1
aphrodite/worker/multi_step_model_runner.py

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
     from aphrodite.attention.backends.abstract import AttentionBackend
 
 
-MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
+MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer", "rocm-flash-attn"]
 
 
 def seq_output_builder():