- """Attention layer ROCm GPUs."""
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple, Type
- import torch
- from loguru import logger
- import aphrodite.common.envs as envs
- from aphrodite.attention.backends.abstract import (AttentionBackend,
- AttentionImpl,
- AttentionMetadata,
- AttentionType)
- from aphrodite.attention.backends.utils import (CommonAttentionState,
- CommonMetadataBuilder)
- from aphrodite.attention.ops.paged_attn import (PagedAttention,
- PagedAttentionMetadata)


class ROCmFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "rocm-flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)
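
    # NOTE: judging from the kv_cache description in
    # ROCmFlashAttentionImpl.forward below, the shape returned here is
    # expected to be (2, num_blocks, block_size * num_kv_heads * head_size);
    # this is a reading of this file, not a separate guarantee from
    # PagedAttention.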

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|
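    #
    # Worked example (illustrative numbers, not from a real run): a sequence
    # with 8 already-computed tokens (context_len=8) that receives 4 new
    # tokens in this iteration has query_len=4 and seq_len=12. For a pure
    # decode step, query_len is 1 and seq_len == context_len + 1.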

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata
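
    # Example of how the two views slice a mixed (chunked-prefill) batch,
    # restating the indexing above with made-up sizes: for num_prefills=2,
    # num_prefill_tokens=9 and num_decode_tokens=3, `prefill_metadata`
    # covers slot_mapping[:9] and seq_lens[:2], while `decode_metadata`
    # covers slot_mapping[9:] and seq_lens_tensor[2:]. Both views are
    # cached after the first access.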


class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):

    _metadata_cls = ROCmFlashAttentionMetadata


def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases
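

# The bias built above is, per head h and positions (i, j):
#     bias[h, i, j] = alibi_slopes[h] * (j - i)
# and, when `make_attn_mask=True`, entries above the diagonal are additionally
# set to -inf so the resulting mask is causal. This is a restatement of what
# the code does, kept here as a reference for the two call sites in
# ROCmFlashAttentionImpl.forward.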


class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "ROCmFlashAttention does not support attention logits soft "
                "capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to
        # Triton.
        self.use_triton_flash_attn = envs.APHRODITE_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            from aphrodite.attention.ops.triton_flash_attn import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
        else:
            # If not using Triton, navi3x/navi21/navi10 do not use flash-attn
            # either.
            if torch.cuda.get_device_capability()[0] != 9:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _sdpa_attention
                logger.debug("Using naive attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))
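
    # Shape sketch for repeat_kv (the docstring's equivalence spelled out):
    # a KV tensor of shape (tokens, n_kv_heads, head_dim) becomes
    # (tokens, n_kv_heads * n_rep, head_dim), the same result as
    # `torch.repeat_interleave(x, repeats=n_rep, dim=1)`, so grouped KV heads
    # line up with the query heads for the naive attention path.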

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )
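
            # k_scale / v_scale appear to be the per-tensor scaling factors
            # applied when writing to a quantized (e.g. fp8) KV cache; with
            # the default value of 1.0 and a non-quantized kv_cache_dtype
            # they are effectively no-ops. This is an interpretation of the
            # call above, not additional documentation from PagedAttention.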

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                attn_masks = None
                if self.use_triton_flash_attn:
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=False)  # type: ignore
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prefill_seq_len,
                        prefill_meta.max_prefill_seq_len,
                        True,
                        self.scale,
                        attn_masks[0][None]
                        if attn_masks is not None else None,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=True)  # type: ignore
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        num_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
                        attn_masks,
                    )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                    k_scale,
                    v_scale,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)


def _sdpa_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    seq_lens: List[int],
    num_tokens: int,
    num_heads: int,
    head_size: int,
    scale: float,
    attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    start = 0
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        with torch.backends.cuda.sdp_kernel(enable_math=True,
                                            enable_flash=False,
                                            enable_mem_efficient=False):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end

    return output
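

# Shape contract for _sdpa_attention, derived from its call site in
# ROCmFlashAttentionImpl.forward (not from separate documentation): query,
# key, and value arrive with the token dimension moved to -2, i.e. shaped
# (num_heads, num_tokens, head_size); seq_lens holds the per-sequence lengths
# that sum to num_tokens; each sequence is attended independently with the
# math SDPA backend; and the result is returned token-major as
# (num_tokens, num_heads, head_size).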