layer.py

  1. """Attention layer."""
  2. from typing import List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.attention.backends.abstract import (AttentionMetadata,
  6. AttentionMetadataPerStage)
  7. from aphrodite.attention.selector import get_attn_backend
  8. class Attention(nn.Module):
  9. """Attention layer.
  10. This class takes query, key, and value tensors as input. The input tensors
  11. can either contain prompt tokens or generation tokens.
  12. The class does the following:
  13. 1. Store the input key and value tensors in the KV cache.
  14. 2. Perform (multi-head/multi-query/grouped-query) attention.
  15. 3. Return the output tensor.
  16. """
  17. def __init__(
  18. self,
  19. num_heads: int,
  20. head_size: int,
  21. scale: float,
  22. num_kv_heads: Optional[int] = None,
  23. alibi_slopes: Optional[List[float]] = None,
  24. sliding_window: Optional[int] = None,
  25. ) -> None:
  26. super().__init__()
  27. self.backend = get_attn_backend(torch.get_default_dtype())
  28. impl_cls = self.backend.get_impl_cls()
  29. self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
  30. alibi_slopes, sliding_window)
  31. def forward(
  32. self,
  33. query: torch.Tensor,
  34. key: torch.Tensor,
  35. value: torch.Tensor,
  36. kv_cache: Optional[torch.Tensor],
  37. attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
  38. kv_scale: float = 1.0,
  39. ) -> torch.Tensor:
  40. return self.impl.forward(query, key, value, kv_cache, attn_metadata,
  41. kv_scale)
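

# --- Usage sketch (not part of the original file) ---
# A minimal illustration of how this layer is typically called from a model's
# attention block. The num_heads/head_size values, the tensor shapes, and the
# way kv_cache and attn_metadata are obtained are assumptions for
# demonstration; in practice the engine's model runner allocates the KV cache
# and builds the AttentionMetadata before each forward pass.
#
#     attn = Attention(num_heads=32, head_size=128, scale=128 ** -0.5,
#                      num_kv_heads=8)
#     # query/key/value are flattened per-token tensors,
#     # e.g. [num_tokens, num_heads * head_size]
#     output = attn(query, key, value, kv_cache, attn_metadata)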