  1. """Attention layer with Flash and PagedAttention."""
  2. from typing import List, Optional
  3. from flash_attn import flash_attn_func
  4. import torch
  5. from aphrodite.modeling.metadata import InputMetadata
  6. from aphrodite.modeling.layers.attention.ops.paged_attn import (
  7. PagedAttentionImpl)
  8. class FlashAttentionBackend:
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        # Grouped-query attention: each KV head serves a group of query heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        # flash-attn expects a (left, right) window; (-1, -1) disables the
        # sliding window.
        self.sliding_window = ((self.sliding_window, self.sliding_window)
                               if self.sliding_window is not None else
                               (-1, -1))
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Normal attention: no cached prefix to attend to.
                # flash_attn_func expects tensors of shape
                # [batch_size, seq_len, num_heads, head_size].
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # Prefix-enabled attention: part of the prompt is already in
                # the paged KV cache.
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.alibi_slopes,
                )
        else:
            # Decoding run: attend to the paged KV cache.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
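

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original backend above). It mirrors
# the "normal attention" prompt branch by calling flash_attn_func directly on
# random tensors. Assumptions: a CUDA device is available and the installed
# flash-attn build accepts the window_size argument; shapes and dtypes are
# illustrative only. A real FlashAttentionBackend.forward() call would
# additionally supply paged KV caches and an InputMetadata instance.
if __name__ == "__main__":
    batch_size, seq_len = 2, 16
    num_heads, head_size = 8, 64

    query = torch.randn(batch_size, seq_len, num_heads, head_size,
                        device="cuda", dtype=torch.float16)
    key = torch.randn(batch_size, seq_len, num_heads, head_size,
                      device="cuda", dtype=torch.float16)
    value = torch.randn(batch_size, seq_len, num_heads, head_size,
                        device="cuda", dtype=torch.float16)

    out = flash_attn_func(
        query,
        key,
        value,
        softmax_scale=head_size**-0.5,
        causal=True,
        window_size=(-1, -1),  # (-1, -1) means no sliding window
    )
    print(out.shape)  # torch.Size([2, 16, 8, 64])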