  1. """ Attention layer with torch scaled_dot_product_attention
  2. and PagedAttention."""
  3. from dataclasses import dataclass
  4. from typing import Any, Dict, List, Optional, Tuple, Type
  5. import torch
  6. from aphrodite._ipex_ops import ipex_ops
  7. from aphrodite.attention.backends.abstract import (AttentionBackend,
  8. AttentionImpl,
  9. AttentionMetadata,
  10. AttentionType)
  11. from aphrodite.attention.backends.utils import (CommonAttentionState,
  12. CommonMetadataBuilder)
  13. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  14. PagedAttentionMetadata)
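
# Partition size (in tokens) used by the PagedAttention V2 decode path below;
# each sequence is split into ceil(seq_len / _PARTITION_SIZE) partitions that
# are reduced at the end.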
_PARTITION_SIZE = 512


class IpexAttnBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "ipex-attn"

    @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
        return IpexAttnBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        return IpexAttnMetadata

    @staticmethod
    def get_builder_cls() -> Type["IpexAttnMetadataBuilder"]:
        return IpexAttnMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
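        # Delegates to PagedAttention, whose cache layout is
        # (2, num_blocks, block_size * num_kv_heads * head_size); this is the
        # kv_cache shape documented in IpexAttnBackendImpl.forward below.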
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        from aphrodite._ipex_ops import ipex_ops as ops
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        from aphrodite._ipex_ops import ipex_ops as ops
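        # Each kv_cache tensor stacks the key cache and the value cache along
        # dim 0, so index 0 selects keys and index 1 selects values.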
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)


@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for IpexAttnBackend."""
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    seqlen_q: Optional[torch.Tensor]
    max_seqlen: Optional[int]
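    # seqlen_q and max_seqlen feed ipex_ops.varlen_attention in the prompt
    # path; block_tables and seq_lens_tensor (from PagedAttentionMetadata)
    # drive the paged decode path.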

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when ALiBi
        # slopes are used. This is due to a limitation of the xFormers API.
        # Will not appear in __repr__ and __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None
        return self


class IpexAttnMetadataBuilder(CommonMetadataBuilder[IpexAttnMetadata]):

    _metadata_cls = IpexAttnMetadata


class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "IPEX backend does not support block-sparse attention.")
        if logits_soft_cap is not None:
            raise ValueError("IPEX backend does not support logits_soft_cap.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
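        # With grouped-query attention, each KV head serves
        # num_heads // num_kv_heads query heads; the prompt path repeats
        # K/V by this factor before calling varlen_attention.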
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "IPEX backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")

    def split_kv_cache(
        self,
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
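        # kv_cache is laid out as
        # (2, num_blocks, block_size * num_kv_heads * head_size); the views
        # below expose it as
        #   key_cache:   (num_blocks, num_kv_heads, head_size // x, block_size, x)
        #   value_cache: (num_blocks, num_kv_heads, head_size, block_size)
        # with x = 1, i.e. the head dimension is not split further.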
        x = 1
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: IpexAttnMetadata,  # type: ignore
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks,
                block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert k_scale == 1.0 and v_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "IpexAttnBackendImpl")
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = self.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            ipex_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )
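            # The new keys/values for this step are now stored in the paged
            # cache at the positions given by slot_mapping.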

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=1)

                if attn_metadata.attn_bias is None:
                    if self.alibi_slopes is not None:
                        att_masks = _make_alibi_bias(
                            self.alibi_slopes, query.dtype,
                            attn_metadata.seq_lens)  # type: ignore
                    elif self.sliding_window is not None:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, self.sliding_window,
                            query.dtype)  # type: ignore
                    else:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, None, dtype=query.dtype)
                    attn_metadata.attn_bias = att_masks

                output = torch.empty(
                    (num_tokens, self.num_heads, self.head_size),
                    dtype=query.dtype,
                    device=query.device)
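                # Prompt self-attention: the same seqlen_q / max_seqlen
                # metadata is passed for both the query and the key/value
                # sides of varlen_attention.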
                ipex_ops.varlen_attention(query,
                                          key,
                                          value,
                                          output,
                                          attn_metadata.seqlen_q,
                                          attn_metadata.seqlen_q,
                                          attn_metadata.max_seqlen,
                                          attn_metadata.max_seqlen,
                                          pdropout=0.0,
                                          softmax_scale=self.scale,
                                          zero_tensors=False,
                                          is_causal=True,
                                          return_softmax=False,
                                          gen_=None)
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "IPEX backend doesn't support prefix decoding.")
        else:
            # Decoding run.
            max_seq_len = attn_metadata.max_decode_seq_len
            output = torch.empty_like(query)
            block_size = value_cache.shape[3]
            num_seqs, num_heads, head_size = query.shape
            max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                                  _PARTITION_SIZE)
            # NOTE: We use a simple heuristic to decide whether to use
            # PagedAttention V1 or V2. If the number of partitions is 1, we
            # use V1 to avoid the overhead of reduction. Also, if the number
            # of sequences or heads is large, we use V1 since there is
            # enough work to parallelize.
            # TODO: Tune this heuristic.
            # For context len > 8192, use V2 kernel to avoid shared memory
            # shortage.
            use_v1 = (max_seq_len <= 8192 and
                      (max_num_partitions == 1 or num_seqs * num_heads > 512))
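            # For example, with max_seq_len = 2048 there are 4 partitions,
            # so V1 is chosen only when num_seqs * num_heads exceeds 512.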
            if use_v1:
                # Run PagedAttention V1.
                ipex_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            else:
                # Run PagedAttention V2.
                assert _PARTITION_SIZE % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
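                # V2 writes per-partition results into tmp_output and uses
                # exp_sums / max_logits to reduce them into the final output.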
                ipex_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    k_scale,
                    v_scale,
                )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
        # NOTE: HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None])
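        # Make the bias causal: entries above the diagonal are filled with
        # -inf before being added to the slope-scaled relative positions.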
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype,
            device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
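        # torch.log turns the 0/1 mask into an additive bias: log(1) = 0
        # keeps allowed positions, log(0) = -inf masks everything else.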
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))
    return attn_biases
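

# Illustrative usage sketch (example values only; the engine wires this up
# itself). Assuming a hypothetical model with 32 query heads, 8 KV heads and
# head size 128:
#
#     impl_cls = IpexAttnBackend.get_impl_cls()
#     impl = impl_cls(num_heads=32,
#                     head_size=128,
#                     scale=128 ** -0.5,
#                     num_kv_heads=8,
#                     alibi_slopes=None,
#                     sliding_window=None,
#                     kv_cache_dtype="auto")
#     # output = impl.forward(query, key, value, kv_cache, attn_metadata)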