@@ -1,6 +1,6 @@
 """Attention layer for ROCm GPUs."""
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
 from loguru import logger
@@ -16,6 +16,11 @@ from aphrodite.attention.backends.utils import (CommonAttentionState,
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)

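+# Imported for type annotations only; TYPE_CHECKING avoids a runtime import.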
+if TYPE_CHECKING:
+    from aphrodite.worker.model_runner import (
+        ModelInputForGPUWithSamplingMetadata)
+
 _PARTITION_SIZE_ROCM = 512
 _ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName

@@ -178,6 +183,59 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
         )
         return self._cached_decode_metadata

+    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int, num_seqs: int, num_queries: int):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+        # When using cudagraph, num_seqs is padded to the next captured
+        # batch size, but num_queries tracks the actual number of requests in
+        # the batch. For --enforce-eager mode, num_seqs == num_queries.
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
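+        # The checks below pin down a pure decode batch: no prefill work and
+        # exactly one new token position per running query.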
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.num_decode_tokens == num_seqs
+        assert self.slot_mapping.shape == (num_seqs, )
+        assert self.seq_lens is not None
+        assert len(self.seq_lens) == num_seqs
+        assert self.seq_lens_tensor is not None
+        assert self.seq_lens_tensor.shape == (num_seqs, )
+        assert self.max_query_len == 1
+        assert self.max_prefill_seq_len == 0
+        assert self.max_decode_seq_len == max(self.seq_lens)
+        assert self.query_start_loc is not None
+        assert self.query_start_loc.shape == (num_queries + 1, )
+        assert self.seq_start_loc is not None
+        assert self.seq_start_loc.shape == (num_seqs + 1, )
+        assert self.context_lens_tensor is not None
+        assert self.context_lens_tensor.shape == (num_queries, )
+        assert self.block_tables is not None
+        assert self.block_tables.shape[0] == num_seqs
+        # Update query lengths. Note that we update only queries and not seqs,
+        # since tensors may be padded due to the captured cuda graph batch size
+        for i in range(num_queries):
+            self.seq_lens[i] += 1
+        self.max_decode_seq_len = max(self.seq_lens)
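+        # Fused GPU kernel: writes each sampled token id as the next input
+        # token, advances input_positions and the device-side seq_lens, and
+        # recomputes slot_mapping from the block tables.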
+        ops.advance_step_flashattn(num_seqs=num_seqs,
+                                   num_queries=num_queries,
+                                   block_size=block_size,
+                                   input_tokens=model_input.input_tokens,
+                                   sampled_token_ids=sampled_token_ids,
+                                   input_positions=model_input.input_positions,
+                                   seq_lens=self.seq_lens_tensor,
+                                   slot_mapping=self.slot_mapping,
+                                   block_tables=self.block_tables)
+

 class ROCmFlashAttentionMetadataBuilder(
         CommonMetadataBuilder[ROCmFlashAttentionMetadata]):