- """Attention layer with Flash and PagedAttention."""
- from typing import List, Optional
- from flash_attn import flash_attn_func
- import torch
- from aphrodite.modeling.metadata import InputMetadata
- from aphrodite.modeling.layers.attention.ops.paged_attn import (
- PagedAttentionImpl)
class FlashAttentionBackend:

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        # Grouped-query attention: each KV head serves a contiguous group of
        # query heads, so num_heads must be divisible by num_kv_heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        # flash_attn expects window_size as a (left, right) tuple;
        # (-1, -1) disables sliding-window attention.
        self.sliding_window = ((self.sliding_window, self.sliding_window) if
                               self.sliding_window is not None else (-1, -1))
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors to
        # [num_tokens, num_heads, head_size].
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Normal attention: no cached prefix, so run FlashAttention
                # directly over the full prompt.
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # Prefix-enabled attention: part of the prompt is already in
                # the KV cache, so attend over the cached blocks as well.
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.alibi_slopes,
                )
        else:
            # Decoding run: one new token per sequence, attending over the
            # paged KV cache.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
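

# A minimal construction sketch (not part of the original module); the
# dimensions below are hypothetical Llama-7B-style values chosen only for
# illustration. In practice the backend is created inside a model's attention
# layer, and forward() is invoked by the model runner with an InputMetadata
# and KV-cache tensors that are built elsewhere.
if __name__ == "__main__":
    head_size = 128
    backend = FlashAttentionBackend(
        num_heads=32,
        head_size=head_size,
        scale=head_size**-0.5,
        num_kv_heads=32,      # fewer KV heads than query heads enables GQA/MQA
        alibi_slopes=None,    # set for ALiBi models (e.g. BLOOM)
        sliding_window=None,  # set for sliding-window models (e.g. Mistral)
    )
    print(backend.num_queries_per_kv)  # 1 query head per KV head here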