layer.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. """Attention layer."""
  2. from typing import Any, Dict, List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.attention import AttentionMetadata, AttentionType
  6. from aphrodite.attention.selector import get_attn_backend
  7. from aphrodite.common.config import CacheConfig
  8. from aphrodite.quantization.base_config import QuantizationConfig
  9. from aphrodite.quantization.kv_cache import BaseKVCacheMethod
  10. class Attention(nn.Module):
  11. """Attention layer.
  12. This class takes query, key, and value tensors as input. The input tensors
  13. can either contain prompt tokens or generation tokens.
  14. The class does the following:
  15. 1. Store the input key and value tensors in the KV cache.
  16. 2. Perform (multi-head/multi-query/grouped-query) attention.
  17. 3. Return the output tensor.
  18. """
  19. def __init__(
  20. self,
  21. num_heads: int,
  22. head_size: int,
  23. scale: float,
  24. num_kv_heads: Optional[int] = None,
  25. alibi_slopes: Optional[List[float]] = None,
  26. cache_config: Optional[CacheConfig] = None,
  27. quant_config: Optional[QuantizationConfig] = None,
  28. blocksparse_params: Optional[Dict[str, Any]] = None,
  29. logits_soft_cap: Optional[float] = None,
  30. prefix: str = "",
  31. ) -> None:
  32. super().__init__()
  33. if cache_config is not None:
  34. kv_cache_dtype = cache_config.cache_dtype
  35. block_size = cache_config.block_size
  36. sliding_window = cache_config.sliding_window
  37. else:
  38. kv_cache_dtype = "auto"
  39. block_size = 16
  40. sliding_window = None
  41. if num_kv_heads is None:
  42. num_kv_heads = num_heads
  43. # The default k/v_scale is set to 1.0. This is ignored
  44. # when kv-cache is not fp8, and should be used with
  45. # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
  46. # expect the pre-quantized k/v_scale to be loaded along
  47. # with the model weights.
  48. self.kv_cache_dtype = kv_cache_dtype
  49. self._k_scale = 1.0
  50. self._v_scale = 1.0
  51. quant_method = quant_config.get_quant_method(
  52. self, prefix=prefix) if quant_config else None
  53. if quant_method is not None:
  54. assert isinstance(quant_method, BaseKVCacheMethod)
  55. # TODO: kv cache dtype should be specified in the FP8
  56. # checkpoint config and become the "auto" behavior
  57. if self.kv_cache_dtype == "fp8_e5m2":
  58. raise ValueError("fp8_e5m2 kv-cache is not supported with "
  59. "fp8 checkpoints.")
  60. # If quantization is enabled, we make "k_scale" and "v_scale"
  61. # parameters so that it can be loaded from the model checkpoint.
  62. # The k/v_scale will then be converted back to native float32
  63. # values after weight loading.
  64. self.quant_method = quant_method
  65. self.quant_method.create_weights(self)
  66. # During model initialization, the default dtype is set as the model
  67. # weight and activation dtype.
  68. dtype = torch.get_default_dtype()
  69. attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
  70. sliding_window, dtype, kv_cache_dtype,
  71. block_size, blocksparse_params
  72. is not None)
  73. impl_cls = attn_backend.get_impl_cls()
  74. self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
  75. alibi_slopes, sliding_window, kv_cache_dtype,
  76. blocksparse_params, logits_soft_cap)
  77. def forward(
  78. self,
  79. query: torch.Tensor,
  80. key: torch.Tensor,
  81. value: torch.Tensor,
  82. kv_cache: Optional[torch.Tensor],
  83. attn_metadata: AttentionMetadata,
  84. attn_type: AttentionType = AttentionType.DECODER,
  85. ) -> torch.Tensor:
  86. return self.impl.forward(query,
  87. key,
  88. value,
  89. kv_cache,
  90. attn_metadata,
  91. self._k_scale,
  92. self._v_scale,
  93. attn_type=attn_type)
  94. def extra_repr(self) -> str:
  95. s = f"head_size={self.impl.head_size}" # type: ignore
  96. s += f", num_heads={self.impl.num_heads}" # type: ignore
  97. s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
  98. s += f", scale={self.impl.scale}" # type: ignore
  99. s += f", backend={self.impl.__class__.__name__}"
  100. return s