layer.py
  1. """Attention layer."""
  2. from typing import Any, Dict, List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.attention.backends.abstract import AttentionMetadata
  6. from aphrodite.attention.selector import get_attn_backend
  7. from aphrodite.common.config import CacheConfig
  8. from aphrodite.quantization.base_config import QuantizationConfig
  9. class Attention(nn.Module):
  10. """Attention layer.
  11. This class takes query, key, and value tensors as input. The input tensors
  12. can either contain prompt tokens or generation tokens.
  13. The class does the following:
  14. 1. Store the input key and value tensors in the KV cache.
  15. 2. Perform (multi-head/multi-query/grouped-query) attention.
  16. 3. Return the output tensor.
  17. """
  18. def __init__(
  19. self,
  20. num_heads: int,
  21. head_size: int,
  22. scale: float,
  23. num_kv_heads: Optional[int] = None,
  24. alibi_slopes: Optional[List[float]] = None,
  25. cache_config: Optional[CacheConfig] = None,
  26. quant_config: Optional[QuantizationConfig] = None,
  27. blocksparse_params: Optional[Dict[str, Any]] = None,
  28. ) -> None:
  29. super().__init__()
  30. if cache_config is not None:
  31. kv_cache_dtype = cache_config.cache_dtype
  32. block_size = cache_config.block_size
  33. sliding_window = cache_config.sliding_window
  34. else:
  35. kv_cache_dtype = "auto"
  36. block_size = 16
  37. sliding_window = None
  38. if num_kv_heads is None:
  39. num_kv_heads = num_heads
  40. # The default kv_scale is set to 1.0. This is ignored
  41. # when kv-cache is not fp8, and should be used with
  42. # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
  43. # expect the pre-quantized kv_scale to be loaded along
  44. # with the model weights.
  45. self.kv_cache_dtype = kv_cache_dtype
  46. self._kv_scale = 1.0
  47. quant_method = quant_config.get_quant_method(
  48. self) if quant_config else None
  49. if quant_method is not None:
  50. if self.kv_cache_dtype == "fp8_e5m2":
  51. raise ValueError("fp8_e5m2 kv-cache is not supported with "
  52. "fp8 checkpoints.")
  53. # When FP8 quantization is enabled, we make a parameter
  54. # "kv_scale" so that it can be loaded from FP8 checkpoint.
  55. # The kv_scale will then be converted back
  56. # to self._kv_scale in a native float32 value after weight loading.
  57. self.quant_method = quant_method
  58. self.quant_method.create_weights(self)
  59. # During model initialization, the default dtype is set as the model
  60. # weight and activation dtype.
  61. dtype = torch.get_default_dtype()
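        # NOTE: the attention backend is chosen based on the head geometry,
        # model/kv-cache dtypes, cache block size, and whether block-sparse
        # attention parameters were provided; `self.impl` then wraps that
        # backend's kernel implementation.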
        attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
                                        sliding_window, dtype, kv_cache_dtype,
                                        block_size, blocksparse_params
                                        is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                                 self._kv_scale)

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s
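

# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal construction example, assuming an environment where one of
# aphrodite's attention backends can be selected (e.g. a CUDA-capable setup).
# The head counts and head size below are hypothetical values, not taken
# from any particular model.
if __name__ == "__main__":
    num_heads, num_kv_heads, head_size = 32, 8, 128  # hypothetical GQA shape
    attn = Attention(
        num_heads=num_heads,
        head_size=head_size,
        scale=head_size**-0.5,  # standard 1/sqrt(head_size) scaling
        num_kv_heads=num_kv_heads,
    )
    # repr() includes extra_repr(), so this prints the resolved head
    # configuration and the attention backend that was selected.
    print(attn)
    # At runtime, forward() is called with query/key/value tensors, the
    # KV-cache tensor managed by the engine, and a backend-specific
    # AttentionMetadata instance; constructing those is backend-dependent
    # and omitted here.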