placeholder_attn.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. from dataclasses import dataclass
  2. from typing import TYPE_CHECKING, List, Optional, Tuple, Type
  3. import torch
  4. from aphrodite.attention.backends.abstract import (AttentionBackend,
  5. AttentionImpl,
  6. AttentionMetadata,
  7. AttentionMetadataBuilder)
  8. if TYPE_CHECKING:
  9. from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
  10. # Placeholder attention backend for models like Mamba and embedding models that
  11. # lack attention.
  12. class PlaceholderAttentionBackend(AttentionBackend):
  13. """Placeholder backend for when no attention is needed."""
  14. @staticmethod
  15. def get_name() -> str:
  16. return "No attention"
  17. @staticmethod
  18. def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
  19. return PlaceholderAttentionImpl
  20. @staticmethod
  21. def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
  22. return PlaceholderAttentionMetadataBuilder
  23. @staticmethod
  24. def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
  25. return PlaceholderAttentionMetadata
  26. @staticmethod
  27. def get_kv_cache_shape(
  28. num_blocks: int,
  29. block_size: int,
  30. num_kv_heads: int,
  31. head_size: int,
  32. ) -> Tuple[int, ...]:
  33. return (1, 1, 1, 1, 1)
  34. @staticmethod
  35. def swap_blocks(
  36. src_kv_cache: torch.Tensor,
  37. dst_kv_cache: torch.Tensor,
  38. src_to_dst: torch.Tensor,
  39. ) -> None:
  40. return
  41. @staticmethod
  42. def copy_blocks(
  43. kv_caches: List[torch.Tensor],
  44. src_to_dists: torch.Tensor,
  45. ) -> None:
  46. return
  47. @dataclass
  48. class PlaceholderAttentionMetadata(AttentionMetadata):
  49. """Attention metadata for prefill and decode batched together."""
  50. # (batch_size,). The sequence length per sequence. Sequence length means
  51. # the computed tokens + new tokens None if it is a decoding.
  52. seq_lens: Optional[List[int]]
  53. # seq_lens stored as a tensor.
  54. seq_lens_tensor: Optional[torch.Tensor]
  55. # Maximum query length in the batch. None for decoding.
  56. max_query_len: Optional[int]
  57. # Maximum sequence length among prefill batch. 0 if there are decoding
  58. # requests only.
  59. max_prefill_seq_len: int
  60. # Maximum sequence length among decode batch. 0 if there are prefill
  61. # requests only.
  62. max_decode_seq_len: int
  63. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  64. # the batch, used to index into subquery. E.g., if the subquery length
  65. # is [4, 6], it is [0, 4, 10].
  66. query_start_loc: Optional[torch.Tensor]
  67. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  68. # the batch, used to index into sequence. E.g., if the sequence length is
  69. # [4, 6], it is [0, 4, 10].
  70. seq_start_loc: Optional[torch.Tensor]
  71. # (batch_size,) A tensor of context lengths (tokens that are computed
  72. # so far).
  73. context_lens_tensor: Optional[torch.Tensor]
  74. # (batch_size, max_blocks_per_seq).
  75. # Block addresses per sequence. (Seq id -> list of physical block)
  76. # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
  77. # in the kv cache. Each block can contain up to block_size tokens.
  78. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
  79. # captured.
  80. block_tables: Optional[torch.Tensor]
  81. # Whether or not if cuda graph is enabled.
  82. # Cuda-graph is currently enabled for decoding only.
  83. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
  84. use_cuda_graph: bool
  85. _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
  86. _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
  87. @property
  88. def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
  89. if self.num_prefills == 0:
  90. return None
  91. if self._cached_prefill_metadata is not None:
  92. return self._cached_prefill_metadata
  93. assert self.seq_lens is not None
  94. assert self.seq_lens_tensor is not None
  95. assert self.query_start_loc is not None
  96. assert self.context_lens_tensor is not None
  97. assert self.seq_start_loc is not None
  98. # Placeholders
  99. slot_mapping = torch.empty(0)
  100. block_tables = torch.empty(0)
  101. self._cached_prefill_metadata = PlaceholderAttentionMetadata(
  102. num_prefills=self.num_prefills,
  103. num_prefill_tokens=self.num_prefill_tokens,
  104. num_decode_tokens=0,
  105. slot_mapping=slot_mapping,
  106. seq_lens=self.seq_lens[:self.num_prefills],
  107. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  108. max_query_len=self.max_query_len,
  109. max_prefill_seq_len=self.max_prefill_seq_len,
  110. max_decode_seq_len=0,
  111. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  112. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  113. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  114. block_tables=block_tables,
  115. use_cuda_graph=False,
  116. )
  117. return self._cached_prefill_metadata
  118. @property
  119. def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
  120. if self.num_decode_tokens == 0:
  121. return None
  122. if self._cached_decode_metadata is not None:
  123. return self._cached_decode_metadata
  124. assert self.seq_lens_tensor is not None
  125. # Placeholders
  126. slot_mapping = torch.empty(0)
  127. block_tables = torch.empty(0)
  128. self._cached_decode_metadata = PlaceholderAttentionMetadata(
  129. num_prefills=0,
  130. num_prefill_tokens=0,
  131. num_decode_tokens=self.num_decode_tokens,
  132. slot_mapping=slot_mapping,
  133. seq_lens=None,
  134. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  135. max_query_len=None,
  136. max_prefill_seq_len=0,
  137. max_decode_seq_len=self.max_decode_seq_len,
  138. query_start_loc=None,
  139. seq_start_loc=None,
  140. context_lens_tensor=None,
  141. block_tables=block_tables,
  142. use_cuda_graph=self.use_cuda_graph,
  143. )
  144. return self._cached_decode_metadata
  145. class PlaceholderAttentionMetadataBuilder(
  146. AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
  147. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  148. self.prefill_seq_lens: List[int] = []
  149. self.context_lens: List[int] = []
  150. self.curr_seq_lens: List[int] = []
  151. self.num_prefills = 0
  152. self.num_prefill_tokens = 0
  153. self.num_decode_tokens = 0
  154. self.input_builder = input_builder
  155. self.runner = input_builder.runner
  156. def _add_seq_group(
  157. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  158. chunked_prefill_enabled: bool):
  159. """Add a sequence group to the metadata. Specifically update/append
  160. 1. context length.
  161. """
  162. is_prompt = inter_data.is_prompt
  163. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  164. curr_sliding_window_block) in zip(
  165. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  166. inter_data.orig_seq_lens, inter_data.seq_lens,
  167. inter_data.query_lens, inter_data.context_lens,
  168. inter_data.curr_sliding_window_blocks):
  169. self.context_lens.append(context_len)
  170. if is_prompt:
  171. self.num_prefills += 1
  172. self.num_prefill_tokens += token_len
  173. self.prefill_seq_lens.append(seq_len)
  174. else:
  175. assert query_len == 1, (
  176. "seq_len: {}, context_len: {}, query_len: {}".format(
  177. seq_len, context_len, query_len))
  178. self.num_decode_tokens += query_len
  179. self.curr_seq_lens.append(curr_seq_len)
  180. def build(self, seq_lens: List[int], query_lens: List[int],
  181. cuda_graph_pad_size: int, batch_size: int):
  182. """Build attention metadata with on-device tensors.
  183. Args:
  184. seq_lens: The maybe padded sequence lengths of the input sequences.
  185. query_lens: The query lengths of the input sequences.
  186. cuda_graph_pad_size: The padding size for cuda graph.
  187. -1 if cuda graph is not used.
  188. batch_size: The maybe padded batch size.
  189. """
  190. for inter_data in self.input_builder.inter_data_list:
  191. self._add_seq_group(inter_data,
  192. self.input_builder.chunked_prefill_enabled)
  193. device = self.runner.device
  194. use_captured_graph = cuda_graph_pad_size != -1
  195. logits_soft_cap = getattr(self.runner.model_config.hf_config,
  196. "attn_logit_softcapping", None)
  197. if logits_soft_cap is not None:
  198. raise ValueError(
  199. "Please use Flashinfer backend for models with logits_soft_cap"
  200. " (i.e., Gemma-2). Otherwise, the output might be wrong."
  201. " Set Flashinfer backend by "
  202. "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
  203. max_query_len = max(query_lens)
  204. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  205. max_decode_seq_len = max(self.curr_seq_lens, default=0)
  206. num_decode_tokens = self.num_decode_tokens
  207. if use_captured_graph:
  208. num_decode_tokens = batch_size
  209. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  210. context_lens_tensor = torch.tensor(self.context_lens,
  211. dtype=torch.int,
  212. device=device)
  213. seq_lens_tensor = torch.tensor(seq_lens,
  214. dtype=torch.int,
  215. device=device)
  216. query_lens_tensor = torch.tensor(query_lens,
  217. dtype=torch.long,
  218. device=device)
  219. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  220. dtype=torch.int32,
  221. device=device)
  222. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  223. dtype=torch.int32,
  224. device=device)
  225. torch.cumsum(seq_lens_tensor,
  226. dim=0,
  227. dtype=seq_start_loc.dtype,
  228. out=seq_start_loc[1:])
  229. torch.cumsum(query_lens_tensor,
  230. dim=0,
  231. dtype=query_start_loc.dtype,
  232. out=query_start_loc[1:])
  233. # Placeholders
  234. slot_mapping = torch.empty(0)
  235. block_tables = torch.empty(0)
  236. return PlaceholderAttentionMetadata(
  237. num_prefills=self.num_prefills,
  238. slot_mapping=slot_mapping,
  239. num_prefill_tokens=self.num_prefill_tokens,
  240. num_decode_tokens=num_decode_tokens,
  241. seq_lens=seq_lens,
  242. seq_lens_tensor=seq_lens_tensor,
  243. max_query_len=max_query_len,
  244. max_prefill_seq_len=max_prefill_seq_len,
  245. max_decode_seq_len=max_decode_seq_len,
  246. query_start_loc=query_start_loc,
  247. seq_start_loc=seq_start_loc,
  248. context_lens_tensor=context_lens_tensor,
  249. block_tables=block_tables,
  250. use_cuda_graph=use_captured_graph,
  251. )
  252. class PlaceholderAttentionImpl(AttentionImpl):
  253. def __init__(self, *args, **kwargs) -> None:
  254. return
  255. def forward(self, *args, **kwargs) -> torch.Tensor:
  256. raise NotImplementedError