  1. """Attention layer with Flash and PagedAttention."""
  2. from typing import List, Optional
  3. from flash_attn import flash_attn_varlen_func
  4. import torch
  5. from aphrodite.modeling.metadata import InputMetadata
  6. from aphrodite.modeling.layers.attention.ops.paged_attn import (
  7. PagedAttentionImpl)


class FlashAttentionBackend:
    """
    If the input tensors contain prompt tokens, the layout is as follows:

    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:

    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        # Grouped-query attention: each KV head is shared by
        # num_queries_per_kv query heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        # flash-attn expects the sliding window as a (left, right) tuple;
        # (-1, -1) disables windowing.
        self.sliding_window = ((self.sliding_window, self.sliding_window) if
                               self.sliding_window is not None else (-1, -1))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Normal attention.
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=input_metadata.seq_start_loc,
                    cu_seqlens_k=input_metadata.seq_start_loc,
                    max_seqlen_q=input_metadata.max_seq_len,
                    max_seqlen_k=input_metadata.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # Prefix-enabled attention: part of the prompt's KV is already
                # stored in the paged cache, so use the paged prefix kernel.
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.alibi_slopes,
                )
        else:
            # Decoding run: each sequence contributes a single query token and
            # attends to its cached keys/values via PagedAttention.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
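

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the backend). It assumes the
# aphrodite package and flash-attn are installed; the head configuration
# below (32 query heads, 8 KV heads, head size 128) is a hypothetical
# example, and ``cu_seqlens`` is a local name introduced for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_heads, num_kv_heads, head_size = 32, 8, 128
    backend = FlashAttentionBackend(
        num_heads=num_heads,
        head_size=head_size,
        scale=head_size**-0.5,
        num_kv_heads=num_kv_heads,
    )
    # Two prompts of lengths 3 and 5 packed into one varlen batch: the
    # cumulative sequence lengths consumed by the varlen kernel start at 0
    # and end at the total token count.
    seq_lens = torch.tensor([3, 5], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    print("num_queries_per_kv:", backend.num_queries_per_kv)  # 4
    print("cu_seqlens:", cu_seqlens.tolist())  # [0, 3, 8]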