# pallas.py

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_kv_heads, num_blocks, block_size, head_size)
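    # Illustrative example (not part of the original file): for a hypothetical
    # model with 8 KV heads, head_size 128, and a cache of 1024 blocks of 16
    # slots each, the layout above gives
    #   PallasAttentionBackend.get_kv_cache_shape(
    #       num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128)
    #   == (8, 1024, 16, 128)
    # i.e. (num_kv_heads, num_blocks, block_size, head_size) per K/V tensor.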
    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]
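    # Illustrative example (assumed values, not in the original source):
    # copy_blocks copies whole blocks along dim 1 of each cache tensor, e.g.
    #   src_to_dists = (torch.tensor([0, 2]), torch.tensor([3, 5]))
    #   PallasAttentionBackend.copy_blocks(kv_caches, src_to_dists)
    # behaves like k_cache[:, [3, 5]] = k_cache[:, [0, 2]] (and the same for
    # v_cache), with the buffer-donor hint letting XLA reuse the cache
    # buffers for the in-place update.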

@dataclass
class PallasMetadata(AttentionMetadata):

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_prefills == 0:
            return None

        assert self.num_decode_tokens == 0
        assert self.block_tables is None
        assert self.context_lens is None
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self
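# Illustrative note (not part of the original file): because a batch is either
# all-prefill or all-decode, exactly one of the two properties above yields
# the metadata. For a decode-only batch one would expect, roughly (the
# constructor arguments inherited from AttentionMetadata are assumed here):
#   meta = PallasMetadata(num_prefills=0, num_prefill_tokens=0,
#                         num_decode_tokens=8, slot_mapping=slot_mapping,
#                         block_tables=block_tables,
#                         context_lens=context_lens)
#   assert meta.prefill_metadata is None
#   assert meta.decode_metadata is meta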

class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("ALiBi slopes are not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
        if "lite" not in tpu_type:
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE: If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"
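    # Illustrative note (interpretation, not from the original source): on
    # full-size (non-"lite") TPU cores, the paged-attention kernel's megacore
    # parallelism is presumably split across KV heads when num_kv_heads is
    # even (e.g. num_kv_heads=8 -> "kv_head") and across the batch otherwise
    # (e.g. num_kv_heads=5 -> "batch", used in forward() only when
    # batch_size % 2 == 0).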

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
        attn_metadata: PallasMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache = [num_kv_heads, num_blocks, block_size, head_size]
            value_cache = [num_kv_heads, num_blocks, block_size, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        assert k_scale == 1.0 and v_scale == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        if kv_cache[0] is not None:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            assert seq_len % 16 == 0, (
                "Pallas FlashAttention kernel requires seq_len to be a "
                f"multiple of 16 but got {seq_len}")

            # Handle GQA/MQA.
            if self.num_kv_heads != self.num_heads:
                key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
                key = key.view(batch_size, seq_len, self.num_heads,
                               self.head_size)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                value = value.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
            # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
            # while the input is [batch_size, seq_len, num_heads, d_model].
            # Permute the input to match the required format.
            output = torch.ops.xla.flash_attention(
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
                True,
            )
            output = output.permute(0, 2, 1, 3)
        else:
            # Decoding run.
            assert kv_cache is not None

            pages_per_compute_block = 16  # TODO: Tune this value.
            if self.megacore_mode == "batch" and batch_size % 2 != 0:
                megacore_mode = None
            else:
                megacore_mode = self.megacore_mode

            # NOTE: A temporary workaround to avoid the error:
            # "xla::paged_attention() Expected a value of type 'str' for
            # argument 'megacore_mode' but instead found type 'NoneType'."
            if megacore_mode is not None:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    megacore_mode=megacore_mode,
                )
            else:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                )

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
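# Illustrative shape walk-through for PallasAttentionBackendImpl.forward()
# (assumed sizes, not from the original file): with batch_size=2, seq_len=16,
# num_heads=8, head_size=128, the prefill path views query [2, 16, 1024] as
# [2, 16, 8, 128], permutes it to [2, 8, 16, 128] for
# torch.ops.xla.flash_attention (the trailing True is the causal flag),
# permutes the result back, and returns [2, 16, 1024]. The decode path runs
# with seq_len == 1 and squeezes query to [batch_size, num_heads, head_size]
# before torch.ops.xla.paged_attention.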

def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)
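# Illustrative sketch (not part of the original backend): write_to_kv_cache
# treats each cache as a flat table of num_kv_heads * num_blocks * block_size
# rows of length head_size, and slot_mapping holds one flat row index per row
# of the flattened key/value. A minimal CPU analogue with made-up sizes and
# without the XLA buffer-donation hints:
#
#   import torch
#
#   num_kv_heads, num_blocks, block_size, head_size = 1, 4, 2, 8
#   key_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)
#   key = torch.randn(1, 3, 1, head_size)     # [batch, seq, num_kv_heads, d]
#   slot_mapping = torch.tensor([0, 1, 2])    # flat slot index per token
#   key_cache.flatten(0, 2).index_copy_(0, slot_mapping, key.flatten(0, 2))
#   assert torch.equal(key_cache[0, 0, 0], key[0, 0, 0])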