enc_dec_attention.py 8.2 KB

  1. """Multi-head attention for encoder-decoder models."""
  2. from typing import Optional
  3. import torch
  4. import torch.nn as nn
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import (
  7. BlockDiagonalCausalMask, )
  8. from aphrodite._C import cache_ops
  9. from aphrodite.modeling.metadata import InputMetadata
  10. from aphrodite.common.utils import is_hip
  11. from aphrodite.modeling.layers.attention import paged_attention
  12. _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


class EncDecAttention(nn.Module):
    """Base class for encoder-decoder attention layers."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(
                f"head_size ({self.head_size}) is not supported. "
                f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")


class EncoderAttention(EncDecAttention):
    """Self-attention over the encoder input (prompt phase, no KV cache)."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
    ) -> None:
        super().__init__(num_heads, head_size, scale)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Encoder attention forward pass.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            input_metadata: Metadata for the input, including the attention
                bias.

        Returns:
            Output tensor.
        """
        # query: [batch_size, seq_len, num_heads * head_size]
        # key: [batch_size, seq_len, num_heads * head_size]
        # value: [batch_size, seq_len, num_heads * head_size]
        # attn_bias: [batch_size, seq_len, seq_len]
        # output: [batch_size, seq_len, num_heads * head_size]
        assert input_metadata.is_prompt
        batch_size, seq_len, hidden_size = query.shape

        # Reshape the query, key, and value tensors.
        query = query.view(batch_size, seq_len, self.num_heads,
                           self.head_size)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_heads,
                           self.head_size)

        if input_metadata.attn_bias is None:
            input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens(
                [seq_len] * batch_size)
        input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len]

        # Normal attention.
        out = xops.memory_efficient_attention_forward(
            query,
            key,
            value,
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
            op=(xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0]
                if is_hip() else None),
        )
        output = out.view(batch_size, seq_len, hidden_size)
        return output
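

# Illustrative reference (a sketch, not used by the model runner): when the
# bias is a dense tensor, the fused xformers call above computes the same
# result as the plain-PyTorch attention below. The helper is hypothetical and
# not part of the Aphrodite API; the BlockDiagonalCausalMask case is not
# covered here.
def _reference_encoder_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # query/key/value: [batch_size, seq_len, num_heads, head_size]
    q = query.permute(0, 2, 1, 3)  # -> [batch, heads, seq, head_size]
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale
    if bias is not None:
        # The bias broadcasts against [batch, heads, seq_q, seq_k].
        scores = scores + bias
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)
    return out.permute(0, 2, 1, 3)  # back to [batch, seq, heads, head_size]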


class DecoderAttention(EncDecAttention):
    """Decoder self-attention over generated tokens, with a paged KV cache."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
    ) -> None:
        super().__init__(num_heads, head_size, scale)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ):
        """Decoder attention forward pass.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            key_cache: Key cache tensor.
            value_cache: Value cache tensor.
            input_metadata: Metadata for the input, including the slot
                mapping, block tables, and attention bias.

        Returns:
            Output tensor.
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key, value, key_cache, value_cache,
                input_metadata.slot_mapping[:, -1].flatten().contiguous(),
                input_metadata.kv_cache_dtype)

        max_prompt_len = input_metadata.prompt_lens.max().item()
        block_size = value_cache.shape[3]
        # The first `prompt_table_len` blocks of each block table map the
        # prompt KV cache; decoder self-attention uses the remaining blocks.
        prompt_table_len = (max_prompt_len + block_size - 1) // block_size
        block_tables = (
            input_metadata.block_tables[:, prompt_table_len:].contiguous())

        output = paged_attention(
            query=query,
            key_cache=key_cache,
            value_cache=value_cache,
            block_tables=block_tables,
            context_lens=input_metadata.context_lens,
            max_context_len=input_metadata.max_context_len,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            alibi_slopes=None,
            custom_bias=input_metadata.attn_bias.to(torch.float32),
            kv_cache_dtype=input_metadata.kv_cache_dtype,
            kv_quant_param=input_metadata.kv_quant_params,
        )
        return output.view(batch_size, seq_len, hidden_size)
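

# Layout sketch (an inference from the slicing in DecoderAttention.forward and
# CrossAttention.forward, not a documented API): each sequence's block table
# stores ceil(max_prompt_len / block_size) blocks for the prompt
# (cross-attention) KV cache first, followed by the blocks backing the
# decoder's own KV cache. For example, with max_prompt_len=19 and
# block_size=16, prompt_table_len == (19 + 16 - 1) // 16 == 2, so
# cross-attention reads block_tables[:, :2] and decoder self-attention reads
# block_tables[:, 2:]. The helper name below is hypothetical.
def _split_block_tables(block_tables: torch.Tensor, max_prompt_len: int,
                        block_size: int) -> "tuple[torch.Tensor, torch.Tensor]":
    prompt_table_len = (max_prompt_len + block_size - 1) // block_size
    cross_tables = block_tables[:, :prompt_table_len].contiguous()
    self_tables = block_tables[:, prompt_table_len:].contiguous()
    return cross_tables, self_tables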


class CrossAttention(EncDecAttention):
    """Decoder-to-encoder cross-attention with a paged KV cache."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
    ) -> None:
        super().__init__(num_heads, head_size, scale)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ):
        """Cross attention forward pass.

        Args:
            query: Query tensor.
            key: Key tensor. Only needed in the first (prompt) pass.
            value: Value tensor. Only needed in the first (prompt) pass.
            key_cache: Key cache tensor.
            value_cache: Value cache tensor.
            input_metadata: Metadata for the input, including the slot
                mapping and block tables.

        Returns:
            Output tensor.
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            key = key.view(-1, self.num_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # This only happens during the first (prompt) pass.
        if (input_metadata.is_prompt and key_cache is not None
                and value_cache is not None):
            assert key is not None and value is not None
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping[:, :-1].flatten().contiguous(),
                input_metadata.kv_cache_dtype,
            )

        max_prompt_len = input_metadata.prompt_lens.int().max().item()
        block_size = value_cache.shape[3]
        # Cross-attention reads only the blocks that hold the prompt KV cache.
        prompt_table_len = (max_prompt_len + block_size - 1) // block_size
        block_tables = (
            input_metadata.block_tables[:, :prompt_table_len].contiguous())

        output = paged_attention(
            query=query,
            key_cache=key_cache,
            value_cache=value_cache,
            block_tables=block_tables,
            context_lens=input_metadata.prompt_lens.int(),
            max_context_len=max_prompt_len,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            alibi_slopes=None,
            custom_bias=None,
            kv_cache_dtype=input_metadata.kv_cache_dtype,
            kv_quant_param=input_metadata.kv_quant_params,
        )
        return output.view(batch_size, seq_len, hidden_size)
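

# Usage sketch (illustrative): how one decoder layer might call the two
# attention paths above. All projections, residuals, and the construction of
# InputMetadata are assumed to live in the caller; the argument names here are
# hypothetical and not Aphrodite identifiers.
def _example_decoder_layer_attention(
    self_attn: DecoderAttention,
    cross_attn: CrossAttention,
    q_self: torch.Tensor,
    k_self: torch.Tensor,
    v_self: torch.Tensor,
    q_cross: torch.Tensor,
    encoder_key: Optional[torch.Tensor],
    encoder_value: Optional[torch.Tensor],
    self_kv_cache: "tuple[torch.Tensor, torch.Tensor]",
    cross_kv_cache: "tuple[torch.Tensor, torch.Tensor]",
    input_metadata: InputMetadata,
) -> "tuple[torch.Tensor, torch.Tensor]":
    # Self-attention over the generated tokens; the newest token's K/V are
    # appended to the paged cache on every decode step.
    self_out = self_attn(q_self, k_self, v_self, self_kv_cache[0],
                         self_kv_cache[1], input_metadata)
    # Cross-attention over the encoder output; its K/V are written to the
    # cache only on the prompt pass (encoder_key/encoder_value may be None
    # afterwards) and reused on every decode step.
    cross_out = cross_attn(q_cross, encoder_key, encoder_value,
                           cross_kv_cache[0], cross_kv_cache[1],
                           input_metadata)
    return self_out, cross_out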