  1. """Attention layer."""
  2. from typing import Any, Dict, List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.attention.backends.abstract import (AttentionMetadata,
  6. AttentionType)
  7. from aphrodite.attention.selector import get_attn_backend
  8. from aphrodite.common.config import CacheConfig
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.kv_cache import BaseKVCacheMethod
  11. class Attention(nn.Module):
  12. """Attention layer.
  13. This class takes query, key, and value tensors as input. The input tensors
  14. can either contain prompt tokens or generation tokens.
  15. The class does the following:
  16. 1. Store the input key and value tensors in the KV cache.
  17. 2. Perform (multi-head/multi-query/grouped-query) attention.
  18. 3. Return the output tensor.
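
    Example (a minimal construction sketch; the head counts below are
    illustrative assumptions and a supported attention backend must be
    available at runtime):

        attn = Attention(num_heads=8, head_size=64, scale=64 ** -0.5)
        # Inside a model's forward pass, with metadata built by the engine:
        # output = attn(query, key, value, kv_cache, attn_metadata)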
  19. """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            sliding_window = cache_config.sliding_window
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            sliding_window = None
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
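        # For example (illustrative): with cache_config.cache_dtype set to
        # "fp8_e5m2" the defaults of 1.0 below are used as-is, while an
        # fp8_e4m3 checkpoint is expected to supply its own k_scale/v_scale,
        # which overwrite these defaults during weight loading.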
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO: kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior.
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that they can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
                                        sliding_window, dtype, kv_cache_dtype,
                                        block_size,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        return self.impl.forward(query,
                                 key,
                                 value,
                                 kv_cache,
                                 attn_metadata,
                                 self._k_scale,
                                 self._v_scale,
                                 attn_type=attn_type)

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s