torch_sdpa.py

  1. """ Attention layer with torch scaled_dot_product_attention
  2. and PagedAttention."""
  3. from dataclasses import dataclass
  4. from typing import Dict, List, Optional, Tuple, Type
  5. import torch
  6. from torch.nn.functional import scaled_dot_product_attention
  7. from aphrodite.attention.backends.abstract import (AttentionBackend,
  8. AttentionImpl,
  9. AttentionMetadata,
  10. AttentionMetadataPerStage)
  11. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  12. PagedAttentionMetadata)


class TorchSDPABackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
        return TorchSDPAMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
                        AttentionMetadataPerStage):
    """Metadata for TorchSDPABackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    prompt_lens: Optional[List[int]]

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list (one bias per prompt) because a separate bias must be
        # set for each prompt when ALiBi slopes are used; this is a
        # limitation of the xformers API.
        # Will not appear in __repr__ and __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None
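
# NOTE: in addition to the fields declared above, TorchSDPABackendImpl.forward
# also reads attributes inherited from the metadata base classes:
# kv_cache_dtype, block_tables, context_lens and max_context_len.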


class TorchSDPABackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            assert len(alibi_slopes) == num_heads
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        kv_scale: float,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        if attn_metadata.is_prompt:
            assert attn_metadata.prompt_lens is not None
            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=1)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=1)

                if attn_metadata.attn_bias is None:
                    if self.alibi_slopes is not None:
                        att_masks = _make_alibi_bias(
                            self.alibi_slopes, query.dtype,
                            attn_metadata.prompt_lens)  # type: ignore
                    elif self.sliding_window is not None:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.prompt_lens, self.sliding_window,
                            query.dtype)  # type: ignore
                    else:
                        att_masks = [None] * len(attn_metadata.prompt_lens)
                    attn_metadata.attn_bias = att_masks

                query = query.movedim(0, query.dim() - 2)
                key = key.movedim(0, key.dim() - 2)
                value = value.movedim(0, value.dim() - 2)

                start = 0
                output = torch.empty(
                    (num_tokens, self.num_heads, self.head_size),
                    dtype=query.dtype)
                for prompt_len, mask in zip(attn_metadata.prompt_lens,
                                            attn_metadata.attn_bias):
                    end = start + prompt_len
                    sub_out = scaled_dot_product_attention(
                        query[:, start:end, :],
                        key[:, start:end, :],
                        value[:, start:end, :],
                        attn_mask=mask,
                        dropout_p=0.0,
                        is_causal=not self.need_mask,
                        scale=self.scale).movedim(query.dim() - 2, 0)
                    output[start:end, :, :] = sub_out
                    start = end
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")
        else:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                attn_metadata.block_tables,
                attn_metadata.context_lens,
                attn_metadata.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    prompt_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases = []
    for prompt_len in prompt_lens:
        bias = torch.arange(prompt_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(prompt_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None])
        inf_mask = torch.empty(
            (1, prompt_len, prompt_len),
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases
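
# Worked example for _make_alibi_bias (values are illustrative): with
# prompt_len = 3 and a single head of slope 0.5, the bias computed above is
#     [[ 0.0, -inf, -inf],
#      [-0.5,  0.0, -inf],
#      [-1.0, -0.5,  0.0]]
# i.e. a causal mask whose allowed entries penalise each key in proportion to
# its distance from the query token.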


def _make_sliding_window_bias(
    prompt_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases = []
    for prompt_len in prompt_lens:
        tensor = torch.full(
            (1, prompt_len, prompt_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases
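

# Minimal usage sketch: how the per-prompt ALiBi bias built by
# _make_alibi_bias plugs into torch's scaled_dot_product_attention for a
# single prompt. The head count, head size, prompt length, and slope values
# below are arbitrary example numbers, not taken from any model config.
if __name__ == "__main__":
    num_heads, head_size, prompt_len = 2, 4, 3
    q = torch.randn(num_heads, prompt_len, head_size)
    k = torch.randn(num_heads, prompt_len, head_size)
    v = torch.randn(num_heads, prompt_len, head_size)

    # One causal ALiBi mask per prompt; take the mask for the only prompt.
    slopes = torch.tensor([0.5, 0.25], dtype=torch.float32)
    alibi_mask = _make_alibi_bias(slopes, q.dtype, [prompt_len])[0]

    # The mask already encodes causality, so is_causal stays False here,
    # mirroring the need_mask handling in TorchSDPABackendImpl.forward.
    out = scaled_dot_product_attention(q, k, v,
                                       attn_mask=alibi_mask,
                                       dropout_p=0.0,
                                       scale=head_size**-0.5)
    print(out.shape)  # torch.Size([2, 3, 4])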