# pallas.py

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.
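# The import above registers the Pallas attention kernels used below (e.g.
# `torch.ops.xla.flash_attention` and `torch.ops.xla.paged_attention`) as
# custom ops under the `torch.ops.xla` namespace.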

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
from aphrodite.attention.backends.utils import CommonAttentionState


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_kv_heads, num_blocks, block_size, head_size)
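
    # Illustrative shape (numbers not from this file): with num_kv_heads=8,
    # num_blocks=512, block_size=16 and head_size=128, each of the key/value
    # caches is an [8, 512, 16, 128] tensor, i.e. the paged cache is indexed
    # per KV head first, then by block and in-block slot.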

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
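        # NOTE: Marking the caches as buffer donors lets the compiled openxla
        # graph alias their storage for the updated caches, so the block copy
        # below happens in place instead of materializing a second copy of
        # each cache.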
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
class PallasMetadata(AttentionMetadata):

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_prefills == 0:
            return None

        assert self.num_decode_tokens == 0
        assert self.block_tables is None
        assert self.context_lens is None
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self
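
    # Usage sketch (illustrative, not taken from this file): these properties
    # give callers a checked view of the "all prefill or all decode"
    # invariant, e.g.
    #
    #     if attn_metadata.prefill_metadata is not None:
    #         ...  # batch holds only prefill tokens
    #     elif attn_metadata.decode_metadata is not None:
    #         ...  # batch holds only decode tokens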


class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes are not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        if "lite" not in tpu_type:
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE: If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"
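
        # NOTE: "Megacore" here appears to refer to the two TensorCores on a
        # full (non-"lite") TPU v4+ chip; the Pallas paged-attention kernel
        # can split its work across them either per KV head ("kv_head", which
        # needs an even number of KV heads) or per batch element ("batch",
        # which needs an even batch size and is re-checked in forward()).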

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
        attn_metadata: PallasMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: a (key_cache, value_cache) tuple, each of shape
                [num_kv_heads, num_blocks, block_size, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        assert k_scale == 1.0 and v_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        if kv_cache[0] is not None:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            assert seq_len % 16 == 0, (
                "Pallas FlashAttention kernel requires seq_len to be a "
                f"multiple of 16 but got {seq_len}")

            # Handle GQA/MQA.
            if self.num_kv_heads != self.num_heads:
                key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
                key = key.view(batch_size, seq_len, self.num_heads,
                               self.head_size)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                value = value.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
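                # Example (illustrative numbers): with num_heads=32 and
                # num_kv_heads=8, each KV head is repeated
                # num_queries_per_kv=4 times so that key/value line up with
                # the 32 query heads the flash-attention kernel expects.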
            # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
            # while the input is [batch_size, seq_len, num_heads, d_model].
            # Permute the input to match the required format.
            output = torch.ops.xla.flash_attention(
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
                True,
            )
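            # (The positional `True` above is, as far as we can tell, the
            # kernel's causal-masking flag; decoder self-attention requires a
            # causal mask.)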
            output = output.permute(0, 2, 1, 3)
        else:
            # Decoding run.
            assert kv_cache[0] is not None

            pages_per_compute_block = 16  # TODO: Tune this value.
            if self.megacore_mode == "batch" and batch_size % 2 != 0:
                megacore_mode = None
            else:
                megacore_mode = self.megacore_mode

            # NOTE: A temporary workaround to avoid the error:
            # "xla::paged_attention() Expected a value of type 'str' for
            # argument 'megacore_mode' but instead found type 'NoneType'."
            if megacore_mode is not None:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    megacore_mode=megacore_mode,
                )
            else:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                )

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)
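
# Layout note (illustrative numbers): after `flatten(0, 2)` the caches are
# viewed as [num_kv_heads * num_blocks * block_size, head_size] and the
# incoming keys/values as [batch_size * seq_len * num_kv_heads, head_size], so
# `slot_mapping` must hold one flat row index per (token, kv_head) pair. For
# example, with num_blocks=512 and block_size=16, a token stored at block 3,
# slot 5 for KV head h maps to row h * (512 * 16) + 3 * 16 + 5.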