Browse Source

feat: correctly invoke prefill & decode kernels for cross-attention

AlpinDale 6 months ago
parent
commit
2105e4fd6b

+ 8 - 0
aphrodite/attention/backends/abstract.py

@@ -1,11 +1,18 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
+from enum import Enum, auto
 from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
                     TypeVar)
 
 import torch
 
 
+class AttentionType(Enum):
+    DECODER = auto()  # Decoder attention between previous layer Q/K/V
+    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
+    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V
+
+
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
 
@@ -128,5 +135,6 @@ class AttentionImpl(ABC, Generic[T]):
         kv_cache: torch.Tensor,
         attn_metadata: T,
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         raise NotImplementedError

+ 8 - 1
aphrodite/attention/backends/blocksparse_attn.py

@@ -5,7 +5,8 @@ import torch
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 from aphrodite.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from aphrodite.attention.ops.paged_attn import PagedAttention
@@ -324,6 +325,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
         Args:
@@ -335,6 +337,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "BlocksparseFlashAttentionImpl")
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)

+ 8 - 1
aphrodite/attention/backends/flash_attn.py

@@ -8,7 +8,8 @@ from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 from aphrodite import _custom_ops as ops
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -258,6 +259,7 @@ class FlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
 
@@ -270,6 +272,11 @@ class FlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashAttentionImpl")
         # NOTE: FlashAttention does not support FP8 KV cache.
         assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
 

+ 8 - 1
aphrodite/attention/backends/flashinfer.py

@@ -15,7 +15,8 @@ import torch
 from aphrodite import _custom_ops as ops
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 
 
 class FlashInferBackend(AttentionBackend):
@@ -226,8 +227,14 @@ class FlashInferImpl(AttentionImpl):
         kv_cache: Optional[torch.Tensor],
         attn_metadata: FlashInferMetadata,
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         assert kv_scale == 1.0
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashInferImpl")
         num_tokens, hidden_size = query.shape
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)

+ 8 - 1
aphrodite/attention/backends/ipex_attn.py

@@ -8,7 +8,8 @@ import torch
 from aphrodite._ipex_ops import ipex_ops
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -158,6 +159,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         kv_cache: Optional[torch.Tensor],
         attn_metadata: IpexAttnMetadata,  # type: ignore
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
         Args:
@@ -170,6 +172,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
             shape = [num_tokens, num_heads * head_size]
         """
         assert kv_scale == 1.0
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "IpexAttnBackendImpl")
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)

+ 8 - 1
aphrodite/attention/backends/pallas.py

@@ -7,7 +7,8 @@ import torch_xla.experimental.dynamo_set_buffer_donor
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -133,6 +134,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
         attn_metadata: PallasMetadata,
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
         Args:
@@ -146,6 +148,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
             shape = [batch_size, seq_len, num_heads * head_size]
         """
         assert kv_scale == 1.0
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "PallasAttentionBackendImpl")
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

+ 8 - 1
aphrodite/attention/backends/rocm_flash_attn.py

@@ -8,7 +8,8 @@ from loguru import logger
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -298,6 +299,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
 
@@ -310,6 +312,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "ROCmFlashAttentionImpl")
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)

+ 8 - 1
aphrodite/attention/backends/torch_sdpa.py

@@ -8,7 +8,8 @@ from torch.nn.functional import scaled_dot_product_attention
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 from aphrodite.attention.ops.paged_attn import PagedAttentionMetadata
 from aphrodite.common.utils import is_cpu
 
@@ -146,6 +147,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         kv_cache: Optional[torch.Tensor],
         attn_metadata: TorchSDPAMetadata,  # type: ignore
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
 
@@ -158,6 +160,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "TorchSDPABackendImpl")
         assert kv_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.

+ 7 - 0
aphrodite/attention/backends/utils.py

@@ -0,0 +1,7 @@
+"""Attention backend utils"""
+
+# Error string(s) for encoder/decoder
+# unsupported attention scenarios
+
+STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
+                                 "with encoder/decoder models.")

+ 399 - 82
aphrodite/attention/backends/xformers.py

@@ -6,11 +6,13 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import (AttentionBias,
                                          BlockDiagonalCausalMask,
+                                         BlockDiagonalMask,
                                          LowerTriangularMaskWithTensorBias)
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -64,11 +66,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     dynamically, it should be stored in tensor. The tensor has to be
     updated from `CUDAGraphRunner.forward` API.
     """
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
-    seq_lens_tensor: Optional[torch.Tensor]
 
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
@@ -77,8 +74,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     # |-------------------- seq_len ----------------------|
     #                                   |-- query_len ---|
 
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
     # FIXME: It is for flash attn.
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
@@ -86,26 +84,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     # Maximum sequence length among decode batch. 0 if there are prefill
     # requests only.
     max_decode_seq_len: int
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    query_start_loc: Optional[torch.Tensor]
+
+    # Whether or not if cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens None if it is a decoding.
+    seq_lens: Optional[List[int]] = None
+
     # FIXME: It is for flash attn.
     # (batch_size + 1,). The cumulative sequence lengths of the sequences in
     # the batch, used to index into sequence. E.g., if the sequence length is
     # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
+    seq_start_loc: Optional[torch.Tensor] = None
+
     # (batch_size,) A tensor of context lengths (tokens that are computed
     # so far).
-    context_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor] = None
 
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int] = None
+
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor] = None
+
+    # Self-attention prefill/decode metadata cache
     _cached_prefill_metadata: Optional["XFormersMetadata"] = None
     _cached_decode_metadata: Optional["XFormersMetadata"] = None
 
+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # Encoder sequence lengths representation
+    encoder_seq_lens: Optional[List[int]] = None
+    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
+
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+
+    # Number of tokens input to encoder
+    num_encoder_tokens: Optional[int] = None
+
+    # Cross-attention memory-mapping data structures: slot mapping
+    # and block tables
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    cross_block_tables: Optional[torch.Tensor] = None
+
     def __post_init__(self):
         # Set during the execution of the first attention op.
         # It is a list because it is needed to set per prompt
@@ -113,6 +140,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
         # from xformer API.
         # will not appear in the __repr__ and __init__
         self.attn_bias: Optional[List[AttentionBias]] = None
+        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
+        self.cross_attn_bias: Optional[List[AttentionBias]] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self):
+        '''
+        All attention metadata required for encoder attention is set.
+        '''
+        return ((self.encoder_seq_lens is not None)
+                and (self.encoder_seq_lens_tensor is not None)
+                and (self.max_encoder_seq_len is not None))
+
+    @property
+    def is_all_cross_attn_metadata_set(self):
+        '''
+        All attention metadata required for enc/dec cross-attention is set.
+
+        Superset of encoder attention required metadata.
+        '''
+        return (self.is_all_encoder_attn_metadata_set
+                and (self.cross_slot_mapping is not None)
+                and (self.cross_block_tables is not None))
 
     @property
     def prefill_metadata(self) -> Optional["XFormersMetadata"]:
@@ -120,30 +169,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
             return None
 
         if self._cached_prefill_metadata is not None:
+            # Recover cached prefill-phase attention
+            # metadata structure
             return self._cached_prefill_metadata
 
-        assert self.seq_lens is not None
-        assert self.seq_lens_tensor is not None
-        assert self.query_start_loc is not None
-        assert self.context_lens_tensor is not None
-        assert self.block_tables is not None
-
+        assert ((self.seq_lens is not None)
+                or (self.encoder_seq_lens is not None))
+        assert ((self.seq_lens_tensor is not None)
+                or (self.encoder_seq_lens_tensor is not None))
+
+        # Compute some attn_metadata fields which default to None
+        query_start_loc = (None if self.query_start_loc is None else
+                           self.query_start_loc[:self.num_prefills + 1])
+        slot_mapping = (None if self.slot_mapping is None else
+                        self.slot_mapping[:self.num_prefill_tokens])
+        seq_lens = (None if self.seq_lens is None else
+                    self.seq_lens[:self.num_prefills])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])
+        context_lens_tensor = (None if self.context_lens_tensor is None else
+                               self.context_lens_tensor[:self.num_prefills])
+        block_tables = (None if self.block_tables is None else
+                        self.block_tables[:self.num_prefills])
+
+        # Construct & cache prefill-phase attention metadata structure
         self._cached_prefill_metadata = XFormersMetadata(
             num_prefills=self.num_prefills,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
-            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
-            seq_lens=self.seq_lens[:self.num_prefills],
-            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            slot_mapping=slot_mapping,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
-            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
-            seq_start_loc=None,
-            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
-            block_tables=self.block_tables[:self.num_prefills],
+            query_start_loc=query_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
             use_cuda_graph=False,
-        )
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
         return self._cached_prefill_metadata
 
     @property
@@ -152,29 +221,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
             return None
 
         if self._cached_decode_metadata is not None:
+            # Recover cached decode-phase attention
+            # metadata structure
             return self._cached_decode_metadata
-        assert self.block_tables is not None
-        assert self.seq_lens_tensor is not None
-
+        assert ((self.seq_lens_tensor is not None)
+                or (self.encoder_seq_lens_tensor is not None))
+
+        # Compute some attn_metadata fields which default to None
+        slot_mapping = (None if self.slot_mapping is None else
+                        self.slot_mapping[self.num_prefill_tokens:])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[self.num_prefills:])
+        block_tables = (None if self.block_tables is None else
+                        self.block_tables[self.num_prefills:])
+
+        # Construct & cache decode-phase attention metadata structure
         self._cached_decode_metadata = XFormersMetadata(
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
-            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
-            seq_lens=None,
-            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            max_query_len=None,
+            slot_mapping=slot_mapping,
+            seq_lens_tensor=seq_lens_tensor,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
-            context_lens_tensor=None,
-            block_tables=self.block_tables[self.num_prefills:],
+            block_tables=block_tables,
             use_cuda_graph=self.use_cuda_graph,
-        )
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
         return self._cached_decode_metadata
 
 
+def _get_attn_bias(
+    attn_metadata: XFormersMetadata,
+    attn_type: AttentionType,
+) -> Optional[AttentionBias]:
+    '''
+    Extract appropriate attention bias from attention metadata
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+    * Appropriate attention bias value given the attention type
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        return attn_metadata.attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        return attn_metadata.encoder_attn_bias
+    else:
+        # attn_type == AttentionType.ENCODER_DECODER
+        return attn_metadata.cross_attn_bias
+
+
+def _set_attn_bias(
+    attn_metadata: XFormersMetadata,
+    attn_bias: List[Optional[AttentionBias]],
+    attn_type: AttentionType,
+) -> None:
+    '''
+    Update appropriate attention bias field of attention metadata,
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_bias: The desired attention bias value
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        attn_metadata.attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        attn_metadata.encoder_attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        attn_metadata.cross_attn_bias = attn_bias
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
+def _get_seq_len_block_table_args(
+    attn_metadata: XFormersMetadata,
+    is_prompt: bool,
+    attn_type: AttentionType,
+) -> tuple:
+    '''
+    The particular choice of sequence-length- and block-table-related
+    attributes which should be extracted from attn_metadata is dependent
+    on the type of attention operation.
+
+    Decoder attn -> select entirely decoder self-attention-related fields
+    Encoder/decoder cross-attn -> select encoder sequence lengths & 
+                                  cross-attn block-tables fields
+    Encoder attn -> select encoder sequence lengths fields & no block tables
+    
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention op
+    * is_prompt: True if prefill, False otherwise
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+
+    * Appropriate sequence-lengths tensor
+    * Appropriate max sequence-length scalar
+    * Appropriate block tables (or None)
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        # Decoder self-attention
+        # Choose max_seq_len based on whether we are in prompt_run
+        if is_prompt:
+            max_seq_len = attn_metadata.max_prefill_seq_len
+        else:
+            max_seq_len = attn_metadata.max_decode_seq_len
+        return (attn_metadata.seq_lens_tensor, max_seq_len,
+                attn_metadata.block_tables)
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        # Enc/dec cross-attention KVs match encoder sequence length;
+        # cross-attention utilizes special "cross" block tables
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len,
+                attn_metadata.cross_block_tables)
+    elif attn_type == AttentionType.ENCODER:
+        # No block tables associated with encoder attention
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len, None)
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
 class XFormersImpl(AttentionImpl[XFormersMetadata]):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
@@ -213,7 +399,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         assert blocksparse_params is None, ValueError(
-            "XFormers does not support block-sparse attention.")
+            "XFormer does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -236,51 +422,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
     def forward(
         self,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor],
         attn_metadata: "XFormersMetadata",
         kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
 
+        For decoder-only models: query, key and value must be non-None.
+
+        For encoder/decoder models:
+        * XFormersImpl.forward() may be invoked for both self- and cross-
+          attention layers.
+        * For self-attention: query, key and value must be non-None.
+        * For cross-attention:
+            * Query must be non-None
+            * During prefill, key and value must be non-None; key and value
+              get cached for use during decode.
+            * During decode, key and value may be None, since:
+              (1) key and value tensors were cached during prefill, and
+              (2) cross-attention key and value tensors do not grow during
+                  decode
+        
+        A note on how the attn_type (attention type enum) argument impacts
+        attention forward() behavior:
+    
+            * DECODER: normal decoder-only behavior;
+                use decoder self-attention block table
+            * ENCODER: no KV caching; pass encoder sequence
+                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len) to kernel, in lieu of decoder
+                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
+            * ENCODER_DECODER: cross-attention behavior;
+                use cross-attention block table for caching KVs derived
+                from encoder hidden states; since KV sequence lengths
+                will match encoder sequence lengths, pass encoder sequence
+                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len)
+    
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
             attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. Defaults to decoder self-attention,
+                       which is the Aphrodite default generally
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
 
-        if kv_cache is not None:
+        # Check that appropriate attention metadata attributes are
+        # selected for the desired attention type
+        if (attn_type == AttentionType.ENCODER
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+        elif (attn_type == AttentionType.ENCODER_DECODER
+              and (not attn_metadata.is_all_cross_attn_metadata_set)):
+            raise AttributeError("Encoder/decoder cross-attention "
+                                 "requires setting cross-attention "
+                                 "metadata attributes.")
+
+        query = query.view(-1, self.num_heads, self.head_size)
+        if key is not None:
+            assert value is not None
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+        else:
+            assert value is None
+
+        # Self-attention vs. cross-attention will impact
+        # which KV cache memory-mapping & which
+        # seqlen datastructures we utilize
+
+        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
+            # KV-cache during decoder-self- or
+            # encoder-decoder-cross-attention, but not
+            # during encoder attention.
+            #
+            # Even if there are no new key/value pairs to cache,
+            # we still need to break out key_cache and value_cache
+            # i.e. for later use by paged attention
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype, kv_scale)
-
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+            if (key is not None) and (value is not None):
+
+                if attn_type == AttentionType.ENCODER_DECODER:
+                    # Update cross-attention KV cache (prefill-only)
+                    # During cross-attention decode, key & value will be None,
+                    # preventing this IF-statement branch from running
+                    updated_slot_mapping = attn_metadata.cross_slot_mapping
+                else:
+                    # Update self-attention KV cache (prefill/decode)
+                    updated_slot_mapping = attn_metadata.slot_mapping
+
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory
+                # profiling run.
+                PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                    value_cache,
+                                                    updated_slot_mapping,
+                                                    self.kv_cache_dtype,
+                                                    kv_scale)
+
+        if attn_type != AttentionType.ENCODER:
+            # Decoder self-attention supports chunked prefill.
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+        else:
+            # Encoder attention - chunked prefill is not applicable;
+            # derive token-count from query shape & and treat them
+            # as 100% prefill tokens
+            assert attn_metadata.num_encoder_tokens is not None
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = 0
+
+        if attn_type == AttentionType.DECODER:
+            # Only enforce this shape-constraint for decoder
+            # self-attention
+            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+            assert value.shape[0] == num_prefill_tokens + num_decode_tokens
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
         query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
+        if key is not None and value is not None:
+            key = key[:num_prefill_tokens]
+            value = value[:num_prefill_tokens]
 
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
@@ -292,10 +571,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 # block tables are empty if the prompt does not have a cached
                 # prefix.
                 out = self._run_memory_efficient_xformers_forward(
-                    query, key, value, prefill_meta)
+                    query, key, value, prefill_meta, attn_type=attn_type)
                 assert out.shape == output[:num_prefill_tokens].shape
                 output[:num_prefill_tokens] = out
             else:
+
+                assert prefill_meta.query_start_loc is not None
+                assert prefill_meta.max_query_len is not None
+
                 # prefix-enabled attention
                 # TODO: this triton kernel has regression issue (broke) to
                 # deal with different data types between KV and FP8 KV cache,
@@ -318,13 +601,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 output[:num_prefill_tokens] = out
 
         if decode_meta := attn_metadata.decode_metadata:
+
+            (
+                seq_lens_arg,
+                max_seq_len_arg,
+                block_tables_arg,
+            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
+
             output[num_prefill_tokens:] = PagedAttention.forward_decode(
                 decode_query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_decode_seq_len,
+                block_tables_arg,
+                seq_lens_arg,
+                max_seq_len_arg,
                 self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
@@ -341,6 +631,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: XFormersMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Attention for 1D query of multiple prompts. Multiple prompt
         tokens are flattened in to `query` input.
@@ -354,8 +645,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             key: shape = [num_prefill_tokens, num_kv_heads, head_size]
             value: shape = [num_prefill_tokens, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. Defaults to decoder self-attention,
+                       which is the Aphrodite default generally
         """
-        assert attn_metadata.seq_lens is not None
+
         original_query = query
         if self.num_kv_heads != self.num_heads:
             # GQA/MQA requires the shape [B, M, G, H, K].
@@ -373,18 +668,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         # Set attention bias if not provided. This typically happens at
         # the very attention layer of every iteration.
         # FIXME: This is a hack.
-        if attn_metadata.attn_bias is None:
+        attn_bias = _get_attn_bias(attn_metadata, attn_type)
+        if attn_bias is None:
             if self.alibi_slopes is None:
-                attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    attn_metadata.seq_lens)
+                if (attn_type == AttentionType.ENCODER_DECODER):
+                    assert attn_metadata.seq_lens is not None
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default enc/dec cross-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                elif attn_type == AttentionType.ENCODER:
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default encoder self-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.encoder_seq_lens)
+                else:
+                    assert attn_metadata.seq_lens is not None
+
+                    # Default decoder self-attention mask is causal
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                        attn_metadata.seq_lens)
                 if self.sliding_window is not None:
                     attn_bias = attn_bias.make_local_attention(
                         self.sliding_window)
-                attn_metadata.attn_bias = [attn_bias]
+                attn_bias = [attn_bias]
             else:
-                attn_metadata.attn_bias = _make_alibi_bias(
-                    self.alibi_slopes, self.num_kv_heads, query.dtype,
-                    attn_metadata.seq_lens)
+                assert attn_metadata.seq_lens is not None
+                attn_bias = _make_alibi_bias(self.alibi_slopes,
+                                             self.num_kv_heads, query.dtype,
+                                             attn_metadata.seq_lens)
+
+            _set_attn_bias(attn_metadata, attn_bias, attn_type)
 
         # No alibi slopes.
         # TODO: Too many view operations. Let's try to reduce
@@ -398,7 +714,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 query,
                 key,
                 value,
-                attn_bias=attn_metadata.attn_bias[0],
+                attn_bias=attn_bias[0],
                 p=0.0,
                 scale=self.scale)
             return out.view_as(original_query)
@@ -407,6 +723,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         # FIXME: Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
+        assert attn_metadata.seq_lens is not None
         output = torch.empty_like(original_query)
         start = 0
         for i, seq_len in enumerate(attn_metadata.seq_lens):
@@ -415,7 +732,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 query[None, start:end],
                 key[None, start:end],
                 value[None, start:end],
-                attn_bias=attn_metadata.attn_bias[i],
+                attn_bias=attn_bias[i],
                 p=0.0,
                 scale=self.scale)
             # TODO: Unnecessary copy. Optimize.
@@ -429,11 +746,11 @@ def _make_alibi_bias(
     num_kv_heads: int,
     dtype: torch.dtype,
     seq_lens: List[int],
-) -> LowerTriangularMaskWithTensorBias:
-    attn_biases = []
+) -> List[AttentionBias]:
+    attn_biases: List[AttentionBias] = []
     for seq_len in seq_lens:
         bias = torch.arange(seq_len, dtype=dtype)
-        # NOTE(zhuohan): HF uses
+        # NOTE: HF uses
         #     `bias = bias[None, :].repeat(seq_len, 1)`
         # here. We find that both biases give the same results, but
         # the bias below more accurately follows the original ALiBi

+ 11 - 3
aphrodite/attention/layer.py

@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional
 import torch
 import torch.nn as nn
 
-from aphrodite.attention.backends.abstract import AttentionMetadata
+from aphrodite.attention.backends.abstract import (AttentionMetadata,
+                                                   AttentionType)
 from aphrodite.attention.selector import get_attn_backend
 from aphrodite.common.config import CacheConfig
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -89,9 +90,16 @@ class Attention(nn.Module):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                                 self._kv_scale)
+
+        return self.impl.forward(query,
+                                 key,
+                                 value,
+                                 kv_cache,
+                                 attn_metadata,
+                                 self._kv_scale,
+                                 attn_type=attn_type)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore