# layer.py

  1. """Attention layer."""
  2. from typing import Any, Dict, List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.attention import AttentionMetadata, AttentionType
  6. from aphrodite.attention.selector import get_attn_backend
  7. from aphrodite.common.config import CacheConfig
  8. from aphrodite.quantization.base_config import QuantizationConfig
  9. from aphrodite.quantization.kv_cache import BaseKVCacheMethod
  10. class Attention(nn.Module):
  11. """Attention layer.
  12. This class takes query, key, and value tensors as input. The input tensors
  13. can either contain prompt tokens or generation tokens.
  14. The class does the following:
  15. 1. Store the input key and value tensors in the KV cache.
  16. 2. Perform (multi-head/multi-query/grouped-query) attention.
  17. 3. Return the output tensor.
  18. """
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            sliding_window = cache_config.sliding_window
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            sliding_window = None
            is_attention_free = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO: kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that they can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size, sliding_window, dtype,
                                        kv_cache_dtype, block_size,
                                        is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        return self.impl.forward(query,
                                 key,
                                 value,
                                 kv_cache,
                                 attn_metadata,
                                 self._k_scale,
                                 self._v_scale,
                                 attn_type=attn_type)

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s
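

# ---------------------------------------------------------------------------
# Minimal construction sketch (not part of the original module). It assumes
# aphrodite and torch are installed and that get_attn_backend can pick a
# backend for the current platform and default dtype; the head/dim numbers
# below are illustrative only, not taken from any particular model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    head_size = 128
    # Grouped-query attention: 32 query heads sharing 8 KV heads. With no
    # cache_config, the layer falls back to kv_cache_dtype="auto" and
    # block_size=16 as defined in __init__ above.
    attn = Attention(
        num_heads=32,
        head_size=head_size,
        scale=head_size**-0.5,
        num_kv_heads=8,
    )
    # extra_repr() surfaces the sizes and the backend implementation that
    # get_attn_backend selected for this environment.
    print(attn)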