- """Attention layer."""
- from typing import Any, Dict, List, Optional
- import torch
- import torch.nn as nn
- from aphrodite.attention import AttentionMetadata, AttentionType
- from aphrodite.attention.selector import get_attn_backend
- from aphrodite.common.config import CacheConfig
- from aphrodite.quantization.base_config import QuantizationConfig
- from aphrodite.quantization.kv_cache import BaseKVCacheMethod


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input
    tensors can contain either prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            sliding_window = cache_config.sliding_window
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            sliding_window = None
            is_attention_free = False
        if num_kv_heads is None:
            # Plain multi-head attention by default; num_kv_heads < num_heads
            # gives grouped-query attention, and num_kv_heads == 1 is
            # multi-query attention.
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
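        # Rough sketch of how the scales are used (assumed convention, for
        # illustration only): with an fp8 cache the backend stores roughly
        # key / k_scale and value / v_scale and multiplies the scales back in
        # when reading, so the default of 1.0 leaves cached values unscaled.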
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO: kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior.
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that they can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        # Select the attention backend that matches this configuration and
        # instantiate its implementation class.
        attn_backend = get_attn_backend(head_size, sliding_window, dtype,
                                        kv_cache_dtype, block_size,
                                        is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
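        """Run attention for one batch.

        Delegates to the selected backend implementation, which stores `key`
        and `value` in `kv_cache` (when a cache tensor is provided) and
        computes attention over `query` according to `attn_metadata`. The
        layer's k/v scales are forwarded for use with fp8 KV caches.
        """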
        return self.impl.forward(query,
                                 key,
                                 value,
                                 kv_cache,
                                 attn_metadata,
                                 self._k_scale,
                                 self._v_scale,
                                 attn_type=attn_type)

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s
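

# Minimal usage sketch, illustrative only and not part of the module above.
# It assumes a working Aphrodite installation with at least one supported
# attention backend available; the head counts and head size are arbitrary
# example values.
if __name__ == "__main__":
    head_size = 128
    # 32 query heads sharing 8 KV heads, i.e. grouped-query attention.
    attn = Attention(
        num_heads=32,
        head_size=head_size,
        scale=head_size**-0.5,
        num_kv_heads=8,
    )
    # In a real model, `attn` is called inside the model's forward pass with
    # query/key/value tensors for the current batch, the backend's KV-cache
    # tensor, and the AttentionMetadata built by the model runner.
    print(attn)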