layer.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. """Attention layer."""
  2. from typing import Any, Dict, List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.attention.backends.abstract import (AttentionMetadata,
  6. AttentionType)
  7. from aphrodite.attention.selector import get_attn_backend
  8. from aphrodite.common.config import CacheConfig
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.fp8 import Fp8KVCacheMethod
  11. class Attention(nn.Module):
  12. """Attention layer.
  13. This class takes query, key, and value tensors as input. The input tensors
  14. can either contain prompt tokens or generation tokens.
  15. The class does the following:
  16. 1. Store the input key and value tensors in the KV cache.
  17. 2. Perform (multi-head/multi-query/grouped-query) attention.
  18. 3. Return the output tensor.
  19. """
  20. def __init__(
  21. self,
  22. num_heads: int,
  23. head_size: int,
  24. scale: float,
  25. num_kv_heads: Optional[int] = None,
  26. alibi_slopes: Optional[List[float]] = None,
  27. cache_config: Optional[CacheConfig] = None,
  28. quant_config: Optional[QuantizationConfig] = None,
  29. blocksparse_params: Optional[Dict[str, Any]] = None,
  30. prefix: str = "",
  31. ) -> None:
  32. super().__init__()
  33. if cache_config is not None:
  34. kv_cache_dtype = cache_config.cache_dtype
  35. block_size = cache_config.block_size
  36. sliding_window = cache_config.sliding_window
  37. else:
  38. kv_cache_dtype = "auto"
  39. block_size = 16
  40. sliding_window = None
  41. if num_kv_heads is None:
  42. num_kv_heads = num_heads
  43. # The default k/v_scale is set to 1.0. This is ignored
  44. # when kv-cache is not fp8, and should be used with
  45. # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
  46. # expect the pre-quantized k/v_scale to be loaded along
  47. # with the model weights.
  48. self.kv_cache_dtype = kv_cache_dtype
  49. self._k_scale = 1.0
  50. self._v_scale = 1.0
  51. quant_method = quant_config.get_quant_method(
  52. self, prefix=prefix) if quant_config else None
  53. if quant_method is not None:
  54. assert isinstance(quant_method, Fp8KVCacheMethod)
  55. # TODO: kv cache dtype should be specified in the FP8
  56. # checkpoint config and become the "auto" behavior
  57. if "fp8" in self.kv_cache_dtype:
  58. if self.kv_cache_dtype == "fp8_e5m2":
  59. raise ValueError("fp8_e5m2 kv-cache is not supported with "
  60. "fp8 checkpoints.")
  61. # When FP8 quantization is enabled, we make a parameter
  62. # "k/v_scale" so that it can be loaded from FP8 checkpoint.
  63. # The k/v_scale will then be converted back to self._k/v_scale
  64. # in a native float32 value after weight loading.
  65. self.quant_method = quant_method
  66. self.quant_method.create_weights(self)
  67. # During model initialization, the default dtype is set as the model
  68. # weight and activation dtype.
  69. dtype = torch.get_default_dtype()
  70. attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
  71. sliding_window, dtype, kv_cache_dtype,
  72. block_size, blocksparse_params
  73. is not None)
  74. impl_cls = attn_backend.get_impl_cls()
  75. self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
  76. alibi_slopes, sliding_window, kv_cache_dtype,
  77. blocksparse_params)
  78. def forward(
  79. self,
  80. query: torch.Tensor,
  81. key: torch.Tensor,
  82. value: torch.Tensor,
  83. kv_cache: Optional[torch.Tensor],
  84. attn_metadata: AttentionMetadata,
  85. attn_type: AttentionType = AttentionType.DECODER,
  86. ) -> torch.Tensor:
  87. return self.impl.forward(query,
  88. key,
  89. value,
  90. kv_cache,
  91. attn_metadata,
  92. self._k_scale,
  93. self._v_scale,
  94. attn_type=attn_type)
  95. def extra_repr(self) -> str:
  96. s = f"head_size={self.impl.head_size}" # type: ignore
  97. s += f", num_heads={self.impl.num_heads}" # type: ignore
  98. s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
  99. s += f", scale={self.impl.scale}" # type: ignore
  100. s += f", backend={self.impl.__class__.__name__}"
  101. return s