1
0

placeholder_attn.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. from dataclasses import dataclass
  2. from typing import TYPE_CHECKING, List, Optional, Tuple, Type
  3. import torch
  4. from aphrodite.attention.backends.abstract import (AttentionBackend,
  5. AttentionImpl,
  6. AttentionMetadata,
  7. AttentionMetadataBuilder)
  8. from aphrodite.attention.backends.utils import CommonAttentionState
  9. if TYPE_CHECKING:
  10. from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
  11. # Placeholder attention backend for models like Mamba and embedding models that
  12. # lack attention.
  13. class PlaceholderAttentionBackend(AttentionBackend):
  14. """Placeholder backend for when no attention is needed."""
  15. @staticmethod
  16. def get_name() -> str:
  17. return "No attention"
  18. @staticmethod
  19. def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
  20. return PlaceholderAttentionImpl
  21. @staticmethod
  22. def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
  23. return PlaceholderAttentionMetadataBuilder
  24. @staticmethod
  25. def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
  26. return PlaceholderAttentionMetadata
  27. @staticmethod
  28. def get_state_cls() -> Type["CommonAttentionState"]:
  29. return CommonAttentionState
  30. @staticmethod
  31. def get_kv_cache_shape(
  32. num_blocks: int,
  33. block_size: int,
  34. num_kv_heads: int,
  35. head_size: int,
  36. ) -> Tuple[int, ...]:
  37. return (1, 1, 1, 1, 1)
  38. @staticmethod
  39. def swap_blocks(
  40. src_kv_cache: torch.Tensor,
  41. dst_kv_cache: torch.Tensor,
  42. src_to_dst: torch.Tensor,
  43. ) -> None:
  44. return
  45. @staticmethod
  46. def copy_blocks(
  47. kv_caches: List[torch.Tensor],
  48. src_to_dists: torch.Tensor,
  49. ) -> None:
  50. return
  51. @dataclass
  52. class PlaceholderAttentionMetadata(AttentionMetadata):
  53. """Attention metadata for prefill and decode batched together."""
  54. # (batch_size,). The sequence length per sequence. Sequence length means
  55. # the computed tokens + new tokens None if it is a decoding.
  56. seq_lens: Optional[List[int]]
  57. # seq_lens stored as a tensor.
  58. seq_lens_tensor: Optional[torch.Tensor]
  59. # Maximum query length in the batch. None for decoding.
  60. max_query_len: Optional[int]
  61. # Maximum sequence length among prefill batch. 0 if there are decoding
  62. # requests only.
  63. max_prefill_seq_len: int
  64. # Maximum sequence length among decode batch. 0 if there are prefill
  65. # requests only.
  66. max_decode_seq_len: int
  67. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  68. # the batch, used to index into subquery. E.g., if the subquery length
  69. # is [4, 6], it is [0, 4, 10].
  70. query_start_loc: Optional[torch.Tensor]
  71. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  72. # the batch, used to index into sequence. E.g., if the sequence length is
  73. # [4, 6], it is [0, 4, 10].
  74. seq_start_loc: Optional[torch.Tensor]
  75. # (batch_size,) A tensor of context lengths (tokens that are computed
  76. # so far).
  77. context_lens_tensor: Optional[torch.Tensor]
  78. # (batch_size, max_blocks_per_seq).
  79. # Block addresses per sequence. (Seq id -> list of physical block)
  80. # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
  81. # in the kv cache. Each block can contain up to block_size tokens.
  82. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
  83. # captured.
  84. block_tables: Optional[torch.Tensor]
  85. # Whether or not if cuda graph is enabled.
  86. # Cuda-graph is currently enabled for decoding only.
  87. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
  88. use_cuda_graph: bool
  89. _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
  90. _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
  91. @property
  92. def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
  93. if self.num_prefills == 0:
  94. return None
  95. if self._cached_prefill_metadata is not None:
  96. return self._cached_prefill_metadata
  97. assert self.seq_lens is not None
  98. assert self.seq_lens_tensor is not None
  99. assert self.query_start_loc is not None
  100. assert self.context_lens_tensor is not None
  101. assert self.seq_start_loc is not None
  102. # Placeholders
  103. slot_mapping = torch.empty(0)
  104. block_tables = torch.empty(0)
  105. self._cached_prefill_metadata = PlaceholderAttentionMetadata(
  106. num_prefills=self.num_prefills,
  107. num_prefill_tokens=self.num_prefill_tokens,
  108. num_decode_tokens=0,
  109. slot_mapping=slot_mapping,
  110. seq_lens=self.seq_lens[:self.num_prefills],
  111. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  112. max_query_len=self.max_query_len,
  113. max_prefill_seq_len=self.max_prefill_seq_len,
  114. max_decode_seq_len=0,
  115. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  116. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  117. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  118. block_tables=block_tables,
  119. use_cuda_graph=False,
  120. )
  121. return self._cached_prefill_metadata
  122. @property
  123. def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
  124. if self.num_decode_tokens == 0:
  125. return None
  126. if self._cached_decode_metadata is not None:
  127. return self._cached_decode_metadata
  128. assert self.seq_lens_tensor is not None
  129. # Placeholders
  130. slot_mapping = torch.empty(0)
  131. block_tables = torch.empty(0)
  132. self._cached_decode_metadata = PlaceholderAttentionMetadata(
  133. num_prefills=0,
  134. num_prefill_tokens=0,
  135. num_decode_tokens=self.num_decode_tokens,
  136. slot_mapping=slot_mapping,
  137. seq_lens=None,
  138. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  139. max_query_len=None,
  140. max_prefill_seq_len=0,
  141. max_decode_seq_len=self.max_decode_seq_len,
  142. query_start_loc=None,
  143. seq_start_loc=None,
  144. context_lens_tensor=None,
  145. block_tables=block_tables,
  146. use_cuda_graph=self.use_cuda_graph,
  147. )
  148. return self._cached_decode_metadata
  149. class PlaceholderAttentionMetadataBuilder(
  150. AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
  151. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  152. self.prefill_seq_lens: List[int] = []
  153. self.context_lens: List[int] = []
  154. self.curr_seq_lens: List[int] = []
  155. self.num_prefills = 0
  156. self.num_prefill_tokens = 0
  157. self.num_decode_tokens = 0
  158. self.input_builder = input_builder
  159. self.runner = input_builder.runner
  160. def _add_seq_group(
  161. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  162. chunked_prefill_enabled: bool):
  163. """Add a sequence group to the metadata. Specifically update/append
  164. 1. context length.
  165. """
  166. is_prompt = inter_data.is_prompt
  167. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  168. curr_sliding_window_block) in zip(
  169. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  170. inter_data.orig_seq_lens, inter_data.seq_lens,
  171. inter_data.query_lens, inter_data.context_lens,
  172. inter_data.curr_sliding_window_blocks):
  173. self.context_lens.append(context_len)
  174. if is_prompt:
  175. self.num_prefills += 1
  176. self.num_prefill_tokens += token_len
  177. self.prefill_seq_lens.append(seq_len)
  178. else:
  179. assert query_len == 1, (
  180. "seq_len: {}, context_len: {}, query_len: {}".format(
  181. seq_len, context_len, query_len))
  182. self.num_decode_tokens += query_len
  183. self.curr_seq_lens.append(curr_seq_len)
  184. def build(self, seq_lens: List[int], query_lens: List[int],
  185. cuda_graph_pad_size: int, batch_size: int):
  186. """Build attention metadata with on-device tensors.
  187. Args:
  188. seq_lens: The maybe padded sequence lengths of the input sequences.
  189. query_lens: The query lengths of the input sequences.
  190. cuda_graph_pad_size: The padding size for cuda graph.
  191. -1 if cuda graph is not used.
  192. batch_size: The maybe padded batch size.
  193. """
  194. for inter_data in self.input_builder.inter_data_list:
  195. self._add_seq_group(inter_data,
  196. self.input_builder.chunked_prefill_enabled)
  197. device = self.runner.device
  198. use_captured_graph = cuda_graph_pad_size != -1
  199. logits_soft_cap = getattr(self.runner.model_config.hf_config,
  200. "attn_logit_softcapping", None)
  201. if logits_soft_cap is not None:
  202. raise ValueError(
  203. "Please use Flashinfer backend for models with logits_soft_cap"
  204. " (i.e., Gemma-2). Otherwise, the output might be wrong."
  205. " Set Flashinfer backend by "
  206. "export APHRODITE_ATTENTION_BACKEND=FLASHINFER.")
  207. max_query_len = max(query_lens)
  208. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  209. max_decode_seq_len = max(self.curr_seq_lens, default=0)
  210. num_decode_tokens = self.num_decode_tokens
  211. if use_captured_graph:
  212. num_decode_tokens = batch_size
  213. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  214. context_lens_tensor = torch.tensor(self.context_lens,
  215. dtype=torch.int,
  216. device=device)
  217. seq_lens_tensor = torch.tensor(seq_lens,
  218. dtype=torch.int,
  219. device=device)
  220. query_lens_tensor = torch.tensor(query_lens,
  221. dtype=torch.long,
  222. device=device)
  223. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  224. dtype=torch.int32,
  225. device=device)
  226. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  227. dtype=torch.int32,
  228. device=device)
  229. torch.cumsum(seq_lens_tensor,
  230. dim=0,
  231. dtype=seq_start_loc.dtype,
  232. out=seq_start_loc[1:])
  233. torch.cumsum(query_lens_tensor,
  234. dim=0,
  235. dtype=query_start_loc.dtype,
  236. out=query_start_loc[1:])
  237. # Placeholders
  238. slot_mapping = torch.empty(0)
  239. block_tables = torch.empty(0)
  240. return PlaceholderAttentionMetadata(
  241. num_prefills=self.num_prefills,
  242. slot_mapping=slot_mapping,
  243. num_prefill_tokens=self.num_prefill_tokens,
  244. num_decode_tokens=num_decode_tokens,
  245. seq_lens=seq_lens,
  246. seq_lens_tensor=seq_lens_tensor,
  247. max_query_len=max_query_len,
  248. max_prefill_seq_len=max_prefill_seq_len,
  249. max_decode_seq_len=max_decode_seq_len,
  250. query_start_loc=query_start_loc,
  251. seq_start_loc=seq_start_loc,
  252. context_lens_tensor=context_lens_tensor,
  253. block_tables=block_tables,
  254. use_cuda_graph=use_captured_graph,
  255. )
  256. class PlaceholderAttentionImpl(AttentionImpl):
  257. def __init__(self, *args, **kwargs) -> None:
  258. return
  259. def forward(self, *args, **kwargs) -> torch.Tensor:
  260. raise NotImplementedError