# torch_sdpa.py

  1. """ Attention layer with torch scaled_dot_product_attention
  2. and PagedAttention."""
  3. from dataclasses import dataclass
  4. from typing import Any, Dict, List, Optional, Tuple, Type
  5. import torch
  6. from torch.nn.functional import scaled_dot_product_attention
  7. from aphrodite.attention.backends.abstract import (AttentionBackend,
  8. AttentionImpl,
  9. AttentionMetadata,
  10. AttentionType)
  11. from aphrodite.attention.backends.utils import (CommonAttentionState,
  12. CommonMetadataBuilder)
  13. from aphrodite.attention.ops.paged_attn import PagedAttentionMetadata
  14. from aphrodite.common.utils import is_cpu
  15. if is_cpu():
  16. try:
  17. from aphrodite.attention.ops.ipex_attn import PagedAttention
  18. except ImportError:
  19. from aphrodite.attention.ops.paged_attn import PagedAttention
  20. else:
  21. from aphrodite.attention.ops.paged_attn import PagedAttention
  22. class TorchSDPABackend(AttentionBackend):
  23. @staticmethod
  24. def get_name() -> str:
  25. return "torch-sdpa"
  26. @staticmethod
  27. def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
  28. return TorchSDPABackendImpl
  29. @staticmethod
  30. def get_metadata_cls() -> Type["AttentionMetadata"]:
  31. return TorchSDPAMetadata
  32. @staticmethod
  33. def get_state_cls() -> Type["CommonAttentionState"]:
  34. return CommonAttentionState
  35. @staticmethod
  36. def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
  37. return TorchSDPAMetadataBuilder
  38. @staticmethod
  39. def get_kv_cache_shape(
  40. num_blocks: int,
  41. block_size: int,
  42. num_kv_heads: int,
  43. head_size: int,
  44. ) -> Tuple[int, ...]:
  45. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  46. num_kv_heads, head_size)
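    # Illustrative only: the layout is owned by PagedAttention; per the
    # forward() docstring below it is a flat
    # (2, num_blocks, block_size * num_kv_heads * head_size) tensor, e.g.
    #
    #   TorchSDPABackend.get_kv_cache_shape(
    #       num_blocks=128, block_size=16, num_kv_heads=8, head_size=128)
    #   # -> (2, 128, 16 * 8 * 128) = (2, 128, 16384)
    #
    # (the IPEX variant imported above may choose a different layout).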
    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for TorchSDPABackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because a bias has to be built per prompt when
        # ALiBi slopes are used, due to a limitation of the xformers API.
        # Will not appear in __repr__ and __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self

        return None

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None

        return self
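    # A minimal sketch of how the two properties partition a batch (field
    # values are hypothetical and the constructor call is abridged; the
    # remaining fields come from the AttentionMetadata /
    # PagedAttentionMetadata base classes):
    #
    #   meta = TorchSDPAMetadata(is_prompt=True, seq_lens=[4, 3],
    #                            num_prefills=2, num_decode_tokens=0, ...)
    #   meta.prefill_metadata is meta   # all-prompt batch
    #   meta.decode_metadata is None    # reversed for an all-decode batch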
class TorchSDPAMetadataBuilder(CommonMetadataBuilder[TorchSDPAMetadata]):

    _metadata_cls = TorchSDPAMetadata


class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SDPA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            raise ValueError("Torch SDPA does not support logits soft cap.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "Torch SDPA backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks,
                               block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TorchSDPABackendImpl")
        assert k_scale == 1.0 and v_scale == 1.0
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, k_scale,
                                                v_scale)

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=1)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=1)

                if attn_metadata.attn_bias is None:
                    if self.alibi_slopes is not None:
                        att_masks = _make_alibi_bias(
                            self.alibi_slopes, query.dtype,
                            attn_metadata.seq_lens)  # type: ignore
                    elif self.sliding_window is not None:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, self.sliding_window,
                            query.dtype)  # type: ignore
                    else:
                        att_masks = [None] * len(attn_metadata.seq_lens)
                    attn_metadata.attn_bias = att_masks

                # SDPA expects (batch, num_heads, seq_len, head_size); move
                # the token dimension next-to-last and add a leading batch
                # dim of 1 per sequence in the loop below.
                query = query.movedim(0, query.dim() - 2)
                key = key.movedim(0, key.dim() - 2)
                value = value.movedim(0, value.dim() - 2)

                start = 0
                output = torch.empty(
                    (num_tokens, self.num_heads, self.head_size),
                    dtype=query.dtype)
                for seq_len, mask in zip(attn_metadata.seq_lens,
                                         attn_metadata.attn_bias):
                    end = start + seq_len
                    sub_out = scaled_dot_product_attention(
                        query[None, :, start:end, :],
                        key[None, :, start:end, :],
                        value[None, :, start:end, :],
                        attn_mask=mask,
                        dropout_p=0.0,
                        is_causal=not self.need_mask,
                        scale=self.scale).squeeze(0).movedim(
                            query.dim() - 2, 0)
                    output[start:end, :, :] = sub_out
                    start = end
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")

        else:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.seq_lens_tensor,
                attn_metadata.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
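# A rough sketch of a prompt-phase call (shapes and values are hypothetical;
# in practice the model runner builds the tensors and the metadata object):
#
#   num_tokens, num_heads, num_kv_heads, head_size = 7, 32, 8, 128
#   query = torch.randn(num_tokens, num_heads * head_size)
#   key = torch.randn(num_tokens, num_kv_heads * head_size)
#   value = torch.randn(num_tokens, num_kv_heads * head_size)
#   # With kv_cache=None (and hence no block tables) the prompt path runs
#   # pure torch SDPA and never touches the paged cache.
#   out = impl.forward(query, key, value, kv_cache=None,
#                      attn_metadata=prompt_metadata)
#   out.shape  # (7, 32 * 128)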
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE: HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases
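# Worked example (illustrative numbers only): for a single ALiBi slope
# s = 0.5 and seq_len = 3, the relative-position matrix
# `bias[None, :] - bias[:, None]` is
#
#   [[ 0,  1,  2],
#    [-1,  0,  1],
#    [-2, -1,  0]]
#
# and after scaling by s and adding the upper-triangular -inf mask, the
# bias handed to SDPA for that head becomes
#
#   [[ 0.0, -inf, -inf],
#    [-0.5,  0.0, -inf],
#    [-1.0, -0.5,  0.0]]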
def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases
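# Worked example (illustrative numbers only): for seq_len = 4 and
# window_size = 2, tril() keeps the causal lower triangle, triu(diagonal=-1)
# trims it to a band of width 2, and log() maps 1 -> 0.0 and 0 -> -inf:
#
#   [[ 0.0, -inf, -inf, -inf],
#    [ 0.0,  0.0, -inf, -inf],
#    [-inf,  0.0,  0.0, -inf],
#    [-inf, -inf,  0.0,  0.0]]
#
# i.e. each token attends to itself and at most one previous token.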