@@ -6,11 +6,13 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
+                                        BlockDiagonalMask,
                                         LowerTriangularMaskWithTensorBias)

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
-                                                   AttentionMetadata)
+                                                   AttentionMetadata,
+                                                   AttentionType)
from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                PagedAttentionMetadata)
@@ -64,11 +66,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    seq_lens: Optional[List[int]]
-    # seq_lens stored as a tensor.
-    seq_lens_tensor: Optional[torch.Tensor]

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
@@ -77,8 +74,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
@@ -86,26 +84,55 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    query_start_loc: Optional[torch.Tensor]
+
+    # Whether or not if cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens None if it is a decoding.
+    seq_lens: Optional[List[int]] = None
+
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
+    seq_start_loc: Optional[torch.Tensor] = None
+
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor] = None

-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int] = None
+
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor] = None
+
+    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # Encoder sequence lengths representation
+    encoder_seq_lens: Optional[List[int]] = None
+    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
+
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+
+    # Number of tokens input to encoder
+    num_encoder_tokens: Optional[int] = None
+
+    # Cross-attention memory-mapping data structures: slot mapping
+    # and block tables
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    cross_block_tables: Optional[torch.Tensor] = None
+
    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
@@ -113,6 +140,28 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[AttentionBias]] = None
+        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
+        self.cross_attn_bias: Optional[List[AttentionBias]] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self):
+        '''
+        All attention metadata required for encoder attention is set.
+        '''
+        return ((self.encoder_seq_lens is not None)
+                and (self.encoder_seq_lens_tensor is not None)
+                and (self.max_encoder_seq_len is not None))
+
+    @property
+    def is_all_cross_attn_metadata_set(self):
+        '''
+        All attention metadata required for enc/dec cross-attention is set.
+
+        Superset of encoder attention required metadata.
+        '''
+        return (self.is_all_encoder_attn_metadata_set
+                and (self.cross_slot_mapping is not None)
+                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
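
The two properties above are what the reworked forward() later uses to validate its
inputs. A minimal sketch of how they behave, with placeholder tensor values and
assuming the patched module is importable as aphrodite.attention.backends.xformers
(the constructor kwargs mirror the decode-phase constructor call later in this patch):

import torch

from aphrodite.attention.backends.xformers import XFormersMetadata

# Decoder-only prefill batch: encoder/cross fields stay at their None defaults.
meta = XFormersMetadata(
    num_prefills=1,
    num_prefill_tokens=4,
    num_decode_tokens=0,
    slot_mapping=torch.tensor([0, 1, 2, 3]),
    seq_lens=[4],
    seq_lens_tensor=torch.tensor([4]),
    max_prefill_seq_len=4,
    max_decode_seq_len=0,
    block_tables=torch.tensor([], dtype=torch.int),
    use_cuda_graph=False,
)
assert not meta.is_all_encoder_attn_metadata_set

# Populate the encoder-side attributes (e.g. for a BART/T5-style model).
meta.encoder_seq_lens = [6]
meta.encoder_seq_lens_tensor = torch.tensor([6])
meta.max_encoder_seq_len = 6
assert meta.is_all_encoder_attn_metadata_set

# Cross-attention additionally needs the dedicated "cross" KV-cache mappings.
assert not meta.is_all_cross_attn_metadata_set
meta.cross_slot_mapping = torch.tensor([8, 9, 10, 11, 12, 13])
meta.cross_block_tables = torch.tensor([[0]], dtype=torch.int)
assert meta.is_all_cross_attn_metadata_set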
@@ -120,30 +169,50 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
            return None

        if self._cached_prefill_metadata is not None:
+            # Recover cached prefill-phase attention
+            # metadata structure
            return self._cached_prefill_metadata

-        assert self.seq_lens is not None
-        assert self.seq_lens_tensor is not None
-        assert self.query_start_loc is not None
-        assert self.context_lens_tensor is not None
-        assert self.block_tables is not None
-
+        assert ((self.seq_lens is not None)
+                or (self.encoder_seq_lens is not None))
+        assert ((self.seq_lens_tensor is not None)
+                or (self.encoder_seq_lens_tensor is not None))
+
+        # Compute some attn_metadata fields which default to None
+        query_start_loc = (None if self.query_start_loc is None else
+                           self.query_start_loc[:self.num_prefills + 1])
+        slot_mapping = (None if self.slot_mapping is None else
+                        self.slot_mapping[:self.num_prefill_tokens])
+        seq_lens = (None if self.seq_lens is None else
+                    self.seq_lens[:self.num_prefills])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])
+        context_lens_tensor = (None if self.context_lens_tensor is None else
+                               self.context_lens_tensor[:self.num_prefills])
+        block_tables = (None if self.block_tables is None else
+                        self.block_tables[:self.num_prefills])
+
+        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
-            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
-            seq_lens=self.seq_lens[:self.num_prefills],
-            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            slot_mapping=slot_mapping,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
-            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
-            seq_start_loc=None,
-            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
-            block_tables=self.block_tables[:self.num_prefills],
+            query_start_loc=query_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
            use_cuda_graph=False,
-        )
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
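
The main behavioural change above is that every sliced field is now guarded with a
None check, so a prefill view can be built even when a batch only carries a subset of
the metadata. A small standalone illustration of that guarded-slice pattern (made-up
values, not the real class):

import torch

num_prefills, num_prefill_tokens = 2, 7
seq_lens = [3, 4, 5]            # two prefill sequences followed by one decode sequence
slot_mapping = torch.arange(8)  # 7 prefill tokens + 1 decode token

# Same expression shape as in prefill_metadata: slice only when the field is present.
prefill_seq_lens = (None if seq_lens is None else seq_lens[:num_prefills])
prefill_slot_mapping = (None if slot_mapping is None else
                        slot_mapping[:num_prefill_tokens])

assert prefill_seq_lens == [3, 4]
assert prefill_slot_mapping.tolist() == [0, 1, 2, 3, 4, 5, 6]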
@@ -152,29 +221,146 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
            return None

        if self._cached_decode_metadata is not None:
+            # Recover cached decode-phase attention
+            # metadata structure
            return self._cached_decode_metadata
-        assert self.block_tables is not None
-        assert self.seq_lens_tensor is not None
-
+        assert ((self.seq_lens_tensor is not None)
+                or (self.encoder_seq_lens_tensor is not None))
+
+        # Compute some attn_metadata fields which default to None
+        slot_mapping = (None if self.slot_mapping is None else
+                        self.slot_mapping[self.num_prefill_tokens:])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[self.num_prefills:])
+        block_tables = (None if self.block_tables is None else
+                        self.block_tables[self.num_prefills:])
+
+        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
-            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
-            seq_lens=None,
-            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            max_query_len=None,
+            slot_mapping=slot_mapping,
+            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
-            context_lens_tensor=None,
-            block_tables=self.block_tables[self.num_prefills:],
+            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
-        )
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata


+def _get_attn_bias(
+    attn_metadata: XFormersMetadata,
+    attn_type: AttentionType,
+) -> Optional[AttentionBias]:
+    '''
+    Extract appropriate attention bias from attention metadata
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+    * Appropriate attention bias value given the attention type
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        return attn_metadata.attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        return attn_metadata.encoder_attn_bias
+    else:
+        # attn_type == AttentionType.ENCODER_DECODER
+        return attn_metadata.cross_attn_bias
+
+
+def _set_attn_bias(
+    attn_metadata: XFormersMetadata,
+    attn_bias: List[Optional[AttentionBias]],
+    attn_type: AttentionType,
+) -> None:
+    '''
+    Update appropriate attention bias field of attention metadata,
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_bias: The desired attention bias value
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        attn_metadata.attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        attn_metadata.encoder_attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        attn_metadata.cross_attn_bias = attn_bias
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
+def _get_seq_len_block_table_args(
+    attn_metadata: XFormersMetadata,
+    is_prompt: bool,
+    attn_type: AttentionType,
+) -> tuple:
+    '''
+    The particular choice of sequence-length- and block-table-related
+    attributes which should be extracted from attn_metadata is dependent
+    on the type of attention operation.
+
+    Decoder attn -> select entirely decoder self-attention-related fields
+    Encoder/decoder cross-attn -> select encoder sequence lengths &
+                                  cross-attn block-tables fields
+    Encoder attn -> select encoder sequence lengths fields & no block tables
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention op
+    * is_prompt: True if prefill, False otherwise
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+
+    * Appropriate sequence-lengths tensor
+    * Appropriate max sequence-length scalar
+    * Appropriate block tables (or None)
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        # Decoder self-attention
+        # Choose max_seq_len based on whether we are in prompt_run
+        if is_prompt:
+            max_seq_len = attn_metadata.max_prefill_seq_len
+        else:
+            max_seq_len = attn_metadata.max_decode_seq_len
+        return (attn_metadata.seq_lens_tensor, max_seq_len,
+                attn_metadata.block_tables)
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        # Enc/dec cross-attention KVs match encoder sequence length;
+        # cross-attention utilizes special "cross" block tables
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len,
+                attn_metadata.cross_block_tables)
+    elif attn_type == AttentionType.ENCODER:
+        # No block tables associated with encoder attention
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len, None)
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
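
The three module-level helpers added above centralize the per-attention-type dispatch
that forward() performs further down. A sketch of how the selection behaves, using a
stand-in object instead of a full XFormersMetadata, string placeholders instead of
real tensors, and assuming the patched module is importable as
aphrodite.attention.backends.xformers:

from types import SimpleNamespace

from aphrodite.attention.backends.abstract import AttentionType
from aphrodite.attention.backends.xformers import _get_seq_len_block_table_args

meta = SimpleNamespace(
    seq_lens_tensor="decoder_seq_lens_tensor",
    max_prefill_seq_len=32,
    max_decode_seq_len=33,
    block_tables="decoder_block_tables",
    encoder_seq_lens_tensor="encoder_seq_lens_tensor",
    max_encoder_seq_len=17,
    cross_block_tables="cross_block_tables",
)

# Decoder self-attention during decode: decoder lengths and decoder block tables.
assert (_get_seq_len_block_table_args(meta, False, AttentionType.DECODER) ==
        ("decoder_seq_lens_tensor", 33, "decoder_block_tables"))

# Enc/dec cross-attention: encoder lengths plus the dedicated cross block tables.
assert (_get_seq_len_block_table_args(meta, False, AttentionType.ENCODER_DECODER) ==
        ("encoder_seq_lens_tensor", 17, "cross_block_tables"))

# Encoder attention: encoder lengths and no block tables at all.
assert (_get_seq_len_block_table_args(meta, False, AttentionType.ENCODER) ==
        ("encoder_seq_lens_tensor", 17, None))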
@@ -213,7 +399,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, ValueError(
-            "XFormers does not support block-sparse attention.")
+            "XFormers does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
@@ -236,51 +422,144 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
    def forward(
        self,
        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor],
        attn_metadata: "XFormersMetadata",
        kv_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

+        For decoder-only models: query, key and value must be non-None.
+
+        For encoder/decoder models:
+        * XFormersImpl.forward() may be invoked for both self- and cross-
+          attention layers.
+        * For self-attention: query, key and value must be non-None.
+        * For cross-attention:
+            * Query must be non-None
+            * During prefill, key and value must be non-None; key and value
+              get cached for use during decode.
+            * During decode, key and value may be None, since:
+              (1) key and value tensors were cached during prefill, and
+              (2) cross-attention key and value tensors do not grow during
+                  decode
+
+        A note on how the attn_type (attention type enum) argument impacts
+        attention forward() behavior:
+
+            * DECODER: normal decoder-only behavior;
+                use decoder self-attention block table
+            * ENCODER: no KV caching; pass encoder sequence
+                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len) to kernel, in lieu of decoder
+                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
+            * ENCODER_DECODER: cross-attention behavior;
+                use cross-attention block table for caching KVs derived
+                from encoder hidden states; since KV sequence lengths
+                will match encoder sequence lengths, pass encoder sequence
+                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len)
+
        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. Defaults to decoder self-attention,
+                       which is the Aphrodite default generally
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        # Check that appropriate attention metadata attributes are
+        # selected for the desired attention type
+        if (attn_type == AttentionType.ENCODER
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+        elif (attn_type == AttentionType.ENCODER_DECODER
+              and (not attn_metadata.is_all_cross_attn_metadata_set)):
+            raise AttributeError("Encoder/decoder cross-attention "
+                                 "requires setting cross-attention "
+                                 "metadata attributes.")
+
+        query = query.view(-1, self.num_heads, self.head_size)
+        if key is not None:
+            assert value is not None
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+        else:
+            assert value is None
+
+        # Self-attention vs. cross-attention will impact
+        # which KV cache memory-mapping & which
+        # seqlen datastructures we utilize
+
+        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
+            # KV-cache during decoder-self- or
+            # encoder-decoder-cross-attention, but not
+            # during encoder attention.
+            #
+            # Even if there are no new key/value pairs to cache,
+            # we still need to break out key_cache and value_cache
+            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype, kv_scale)
-
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+            if (key is not None) and (value is not None):
+
+                if attn_type == AttentionType.ENCODER_DECODER:
+                    # Update cross-attention KV cache (prefill-only)
+                    # During cross-attention decode, key & value will be None,
+                    # preventing this IF-statement branch from running
+                    updated_slot_mapping = attn_metadata.cross_slot_mapping
+                else:
+                    # Update self-attention KV cache (prefill/decode)
+                    updated_slot_mapping = attn_metadata.slot_mapping
+
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory
+                # profiling run.
+                PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                    value_cache,
+                                                    updated_slot_mapping,
+                                                    self.kv_cache_dtype,
+                                                    kv_scale)
+
+        if attn_type != AttentionType.ENCODER:
+            # Decoder self-attention supports chunked prefill.
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+        else:
+            # Encoder attention - chunked prefill is not applicable;
+            # derive token-count from query shape & and treat them
+            # as 100% prefill tokens
+            assert attn_metadata.num_encoder_tokens is not None
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = 0
+
+        if attn_type == AttentionType.DECODER:
+            # Only enforce this shape-constraint for decoder
+            # self-attention
+            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
+        if key is not None and value is not None:
+            key = key[:num_prefill_tokens]
+            value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens
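
The control flow above leaves the decoder-only path unchanged by default (attn_type
defaults to DECODER) while letting encoder/decoder models reuse the same entry point.
A schematic of the three call patterns the new attn_type argument enables; `impl` is
an XFormersImpl built elsewhere, and the tensors, caches and metadata are assumed to
come from the surrounding model runner rather than being constructed here:

from typing import Optional

import torch

from aphrodite.attention.backends.abstract import AttentionType


def call_attention(impl, role: str, query: torch.Tensor,
                   key: Optional[torch.Tensor], value: Optional[torch.Tensor],
                   kv_cache: Optional[torch.Tensor], attn_metadata) -> torch.Tensor:
    """Dispatch one attention layer through the patched XFormersImpl.forward()."""
    if role == "decoder_self":
        # Unchanged default path: decoder self-attention with its own block table.
        return impl.forward(query, key, value, kv_cache, attn_metadata)
    if role == "encoder_self":
        # No KV caching; encoder_seq_lens / encoder_seq_lens_tensor /
        # max_encoder_seq_len must already be set on attn_metadata.
        return impl.forward(query, key, value, None, attn_metadata,
                            attn_type=AttentionType.ENCODER)
    # Enc/dec cross-attention: key/value may be None during decode because the
    # encoder-derived KVs were written through cross_slot_mapping at prefill time.
    return impl.forward(query, key, value, kv_cache, attn_metadata,
                        attn_type=AttentionType.ENCODER_DECODER)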
@@ -292,10 +571,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
            # block tables are empty if the prompt does not have a cached
            # prefix.
            out = self._run_memory_efficient_xformers_forward(
-                query, key, value, prefill_meta)
+                query, key, value, prefill_meta, attn_type=attn_type)
            assert out.shape == output[:num_prefill_tokens].shape
            output[:num_prefill_tokens] = out
        else:
+
+            assert prefill_meta.query_start_loc is not None
+            assert prefill_meta.max_query_len is not None
+
            # prefix-enabled attention
            # TODO: this triton kernel has regression issue (broke) to
            # deal with different data types between KV and FP8 KV cache,
@@ -318,13 +601,20 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
            output[:num_prefill_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
+
+            (
+                seq_lens_arg,
+                max_seq_len_arg,
+                block_tables_arg,
+            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
+
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_decode_seq_len,
+                block_tables_arg,
+                seq_lens_arg,
+                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
@@ -341,6 +631,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.
@@ -354,8 +645,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. Defaults to decoder self-attention,
+                       which is the Aphrodite default generally
        """
-        assert attn_metadata.seq_lens is not None
+
        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
@@ -373,18 +668,39 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
        # Set attention bias if not provided. This typically happens at
        # the very attention layer of every iteration.
        # FIXME: This is a hack.
-        if attn_metadata.attn_bias is None:
+        attn_bias = _get_attn_bias(attn_metadata, attn_type)
+        if attn_bias is None:
            if self.alibi_slopes is None:
-                attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    attn_metadata.seq_lens)
+                if (attn_type == AttentionType.ENCODER_DECODER):
+                    assert attn_metadata.seq_lens is not None
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default enc/dec cross-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                elif attn_type == AttentionType.ENCODER:
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default encoder self-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.encoder_seq_lens)
+                else:
+                    assert attn_metadata.seq_lens is not None
+
+                    # Default decoder self-attention mask is causal
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                        attn_metadata.seq_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
-                attn_metadata.attn_bias = [attn_bias]
+                attn_bias = [attn_bias]
            else:
-                attn_metadata.attn_bias = _make_alibi_bias(
-                    self.alibi_slopes, self.num_kv_heads, query.dtype,
-                    attn_metadata.seq_lens)
+                assert attn_metadata.seq_lens is not None
+                attn_bias = _make_alibi_bias(self.alibi_slopes,
+                                             self.num_kv_heads, query.dtype,
+                                             attn_metadata.seq_lens)
+
+            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO: Too many view operations. Let's try to reduce
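
The mask selection above is the heart of the change: decoder self-attention keeps its
causal block-diagonal mask, while encoder and cross-attention get the non-causal
BlockDiagonalMask, with query lengths and key/value lengths allowed to differ in the
cross case. A small illustration using the xFormers bias classes directly (the
lengths are made up):

from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         BlockDiagonalMask)

seq_lens = [2, 3]          # decoder (query) lengths per sequence
encoder_seq_lens = [4, 5]  # encoder (key/value) lengths per sequence

# Decoder self-attention: causal, queries and keys share the same lengths.
causal_mask = BlockDiagonalCausalMask.from_seqlens(seq_lens)

# Enc/dec cross-attention: non-causal, queries follow decoder lengths while
# keys/values follow encoder lengths.
cross_mask = BlockDiagonalMask.from_seqlens(seq_lens, encoder_seq_lens)

# Materialize to a dense tensor to inspect: shape is [sum(q_lens), sum(kv_lens)].
dense = cross_mask.materialize(shape=(sum(seq_lens), sum(encoder_seq_lens)))
print(dense.shape)  # torch.Size([5, 9])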
@@ -398,7 +714,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                query,
                key,
                value,
-                attn_bias=attn_metadata.attn_bias[0],
+                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)
@@ -407,6 +723,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
        # FIXME: Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
+        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
@@ -415,7 +732,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
-                attn_bias=attn_metadata.attn_bias[i],
+                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO: Unnecessary copy. Optimize.
@@ -429,11 +746,11 @@ def _make_alibi_bias(
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
-) -> LowerTriangularMaskWithTensorBias:
-    attn_biases = []
+) -> List[AttentionBias]:
+    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
-        # NOTE(zhuohan): HF uses
+        # NOTE: HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi