utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. from contextlib import contextmanager
  2. from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
  3. import numpy as np
  4. import torch
  5. from aphrodite.attention import (AttentionMetadata, AttentionMetadataBuilder,
  6. AttentionState)
  7. from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
  8. if TYPE_CHECKING:
  9. from aphrodite.task_handler.model_runner_base import ModelRunnerBase
  10. # Error string(s) for encoder/decoder
  11. # unsupported attention scenarios
  12. STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
  13. "with encoder/decoder models.")
  14. PAD_SLOT_ID = -1
  15. # Switch to numpy implementation of compute_slot_mapping
  16. # if we have at least this many elements. Could be tuned further.
  17. _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
  18. if TYPE_CHECKING:
  19. from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
  20. def is_block_tables_empty(block_tables: Union[None, Dict]):
  21. """
  22. Check if block_tables is None or a dictionary with all None values.
  23. """
  24. if block_tables is None:
  25. return True
  26. if isinstance(block_tables, dict) and all(
  27. value is None for value in block_tables.values()):
  28. return True
  29. return False
  30. def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
  31. context_len: int, sliding_window: int,
  32. use_v2_block_manager: bool):
  33. """
  34. Compute the start index of slot mapping.
  35. """
  36. start_idx = 0
  37. if is_prompt and sliding_window is not None:
  38. assert use_v2_block_manager or context_len == 0, (
  39. "Prefix caching is currently not supported with "
  40. "sliding window attention in V1 block manager")
  41. # When prefill, we use it to not write slots to kv cache
  42. # to save memory.
  43. start_idx = max(0, query_len - sliding_window)
  44. return start_idx
  45. def _compute_slot_mapping_python(slot_mapping: List[int],
  46. block_table: List[int], range_start: int,
  47. range_end: int, block_size: int):
  48. for i in range(range_start, range_end):
  49. block_number = block_table[i // block_size]
  50. block_offset = i % block_size
  51. slot = block_number * block_size + block_offset
  52. slot_mapping.append(slot)
  53. def _compute_slot_mapping_numpy(slot_mapping: List[int],
  54. block_table: List[int], range_start: int,
  55. range_end: int, block_size: int):
  56. block_table_array = np.array(block_table)
  57. idx = np.arange(range_start, range_end)
  58. block_offset = idx % block_size
  59. idx //= block_size
  60. seq_slot_mapping_array = block_table_array[idx]
  61. seq_slot_mapping_array *= block_size
  62. seq_slot_mapping_array += block_offset
  63. slot_mapping.extend(seq_slot_mapping_array)
  64. def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
  65. seq_id: int, seq_len: int, context_len: int,
  66. start_idx: int, block_size: int,
  67. block_tables: Dict[int, List[int]]):
  68. """
  69. Compute slot mapping.
  70. """
  71. if is_profile_run:
  72. # During memory profiling, the block tables are not
  73. # initialized yet. In this case, we just use a dummy
  74. # slot mapping.
  75. # In embeddings, the block tables are {seq_id: None}.
  76. slot_mapping.extend([PAD_SLOT_ID] * seq_len)
  77. return
  78. # Mask the [0, start_idx) tokens of the prompt with
  79. # PAD_SLOT_ID, where start_idx is max(0, seq_len -
  80. # sliding_window). For example, if the prompt len is 10,
  81. # sliding window is 8, and block size is 4, the first two
  82. # tokens are masked and the slot mapping will be
  83. # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  84. padding_mask_len = max(0, start_idx - context_len)
  85. slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
  86. range_start = max(start_idx, context_len)
  87. range_end = seq_len
  88. numel = range_end - range_start
  89. block_table = block_tables[seq_id]
  90. # numpy implementation will be faster than python if we have
  91. # many elements, otherwise it will be slower.
  92. if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
  93. _compute_slot_mapping_python(slot_mapping, block_table, range_start,
  94. range_end, block_size)
  95. else:
  96. _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
  97. range_end, block_size)
  98. TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
  99. class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
  100. _metadata_cls: Type[TAttentionMetadata]
  101. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  102. self.slot_mapping: List[int] = []
  103. self.prefill_seq_lens: List[int] = []
  104. self.context_lens: List[int] = []
  105. self.block_tables: List[List[int]] = []
  106. self.curr_seq_lens: List[int] = []
  107. self.num_prefills = 0
  108. self.num_prefill_tokens = 0
  109. self.num_decode_tokens = 0
  110. self.input_builder = input_builder
  111. self.runner = input_builder.runner
  112. self.sliding_window = input_builder.sliding_window
  113. self.block_size = input_builder.block_size
  114. self.use_v2_block_manager = (
  115. input_builder.scheduler_config.use_v2_block_manager)
  116. def _add_seq_group(
  117. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  118. chunked_prefill_enabled: bool):
  119. is_prompt = inter_data.is_prompt
  120. block_tables = inter_data.block_tables
  121. computed_block_nums = inter_data.computed_block_nums
  122. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  123. curr_sliding_window_block) in zip(
  124. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  125. inter_data.orig_seq_lens, inter_data.seq_lens,
  126. inter_data.query_lens, inter_data.context_lens,
  127. inter_data.curr_sliding_window_blocks):
  128. self.context_lens.append(context_len)
  129. if is_prompt:
  130. self.num_prefills += 1
  131. self.num_prefill_tokens += token_len
  132. self.prefill_seq_lens.append(seq_len)
  133. else:
  134. assert query_len == 1, (
  135. "seq_len: {}, context_len: {}, query_len: {}".format(
  136. seq_len, context_len, query_len))
  137. self.num_decode_tokens += query_len
  138. self.curr_seq_lens.append(curr_seq_len)
  139. # Compute block table.
  140. # TODO: Combine chunked prefill and prefix caching by
  141. # only allowing multiple of block_size chunk size.
  142. # NOTE: This only works for oooooooxxx style attention.
  143. block_table = []
  144. if inter_data.prefix_cache_hit:
  145. block_table = computed_block_nums
  146. elif ((chunked_prefill_enabled or not is_prompt)
  147. and block_tables is not None):
  148. block_table = block_tables[seq_id][-curr_sliding_window_block:]
  149. self.block_tables.append(block_table)
  150. # Compute slot mapping.
  151. is_profile_run = is_block_tables_empty(block_tables)
  152. start_idx = compute_slot_mapping_start_idx(
  153. is_prompt, query_len, context_len, self.sliding_window,
  154. self.use_v2_block_manager)
  155. compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
  156. seq_len, context_len, start_idx,
  157. self.block_size, inter_data.block_tables)
  158. def build(self, seq_lens: List[int], query_lens: List[int],
  159. cuda_graph_pad_size: int, batch_size: int):
  160. """Build attention metadata with on-device tensors.
  161. Args:
  162. seq_lens: The maybe padded sequence lengths of the input sequences.
  163. query_lens: The query lengths of the input sequences.
  164. cuda_graph_pad_size: The padding size for cuda graph.
  165. -1 if cuda graph is not used.
  166. batch_size: The maybe padded batch size.
  167. """
  168. for inter_data in self.input_builder.inter_data_list:
  169. self._add_seq_group(inter_data,
  170. self.input_builder.chunked_prefill_enabled)
  171. device = self.runner.device
  172. use_captured_graph = cuda_graph_pad_size != -1
  173. max_query_len = max(query_lens)
  174. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  175. max_decode_seq_len = max(self.curr_seq_lens, default=0)
  176. num_decode_tokens = self.num_decode_tokens
  177. if use_captured_graph:
  178. self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
  179. self.block_tables.extend([] * cuda_graph_pad_size)
  180. num_decode_tokens = batch_size
  181. # The shape of graph_block_tables is
  182. # [max batch size, max context len // block size].
  183. input_block_tables = self.runner.graph_block_tables[:batch_size]
  184. for i, block_table in enumerate(self.block_tables):
  185. if block_table:
  186. input_block_tables[i, :len(block_table)] = block_table
  187. block_tables = torch.from_numpy(input_block_tables).to(
  188. device, non_blocking=True)
  189. else:
  190. block_tables = make_tensor_with_pad(
  191. self.block_tables,
  192. pad=0,
  193. dtype=torch.int,
  194. device=device,
  195. )
  196. assert max_query_len > 0, "query_lens: {}".format(query_lens)
  197. assert device is not None
  198. context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
  199. device, self.runner.pin_memory)
  200. seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
  201. self.runner.pin_memory)
  202. query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
  203. self.runner.pin_memory)
  204. slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
  205. device, self.runner.pin_memory)
  206. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  207. dtype=torch.int32,
  208. device=device)
  209. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  210. dtype=torch.int32,
  211. device=device)
  212. torch.cumsum(seq_lens_tensor,
  213. dim=0,
  214. dtype=seq_start_loc.dtype,
  215. out=seq_start_loc[1:])
  216. torch.cumsum(query_lens_tensor,
  217. dim=0,
  218. dtype=query_start_loc.dtype,
  219. out=query_start_loc[1:])
  220. return self._metadata_cls( # type: ignore
  221. num_prefills=self.num_prefills,
  222. slot_mapping=slot_mapping_tensor,
  223. num_prefill_tokens=self.num_prefill_tokens,
  224. num_decode_tokens=num_decode_tokens,
  225. seq_lens=seq_lens,
  226. seq_lens_tensor=seq_lens_tensor,
  227. max_query_len=max_query_len,
  228. max_prefill_seq_len=max_prefill_seq_len,
  229. max_decode_seq_len=max_decode_seq_len,
  230. query_start_loc=query_start_loc,
  231. seq_start_loc=seq_start_loc,
  232. context_lens_tensor=context_lens_tensor,
  233. block_tables=block_tables,
  234. use_cuda_graph=use_captured_graph,
  235. )
  236. class CommonAttentionState(AttentionState):
  237. def __init__(self, runner: "ModelRunnerBase"):
  238. self.runner = runner
  239. self._is_graph_capturing = False
  240. @contextmanager
  241. def graph_capture(self, max_batch_size: int):
  242. self._is_graph_capturing = True
  243. self._graph_slot_mapping = torch.full((max_batch_size, ),
  244. PAD_SLOT_ID,
  245. dtype=torch.long,
  246. device=self.runner.device)
  247. self._graph_seq_lens = torch.ones(max_batch_size,
  248. dtype=torch.int32,
  249. device=self.runner.device)
  250. self._graph_block_tables = torch.from_numpy(
  251. self.runner.graph_block_tables).to(device=self.runner.device)
  252. yield
  253. self._is_graph_capturing = False
  254. del self._graph_slot_mapping
  255. del self._graph_seq_lens
  256. del self._graph_block_tables
  257. def graph_clone(self, batch_size: int) -> "CommonAttentionState":
  258. assert self._is_graph_capturing
  259. return self.__class__(self.runner)
  260. def graph_capture_get_metadata_for_batch(self, batch_size: int):
  261. assert self._is_graph_capturing
  262. attn_metadata = self.runner.attn_backend.make_metadata(
  263. num_prefills=0,
  264. num_prefill_tokens=0,
  265. num_decode_tokens=batch_size,
  266. slot_mapping=self._graph_slot_mapping[:batch_size],
  267. seq_lens=None,
  268. seq_lens_tensor=self._graph_seq_lens[:batch_size],
  269. max_query_len=None,
  270. max_prefill_seq_len=0,
  271. max_decode_seq_len=self.runner.max_seq_len_to_capture,
  272. query_start_loc=None,
  273. seq_start_loc=None,
  274. context_lens_tensor=None,
  275. block_tables=self._graph_block_tables[:batch_size],
  276. use_cuda_graph=True,
  277. )
  278. return attn_metadata
  279. def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
  280. return {
  281. "slot_mapping": attn_metadata.slot_mapping,
  282. "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
  283. "block_tables": attn_metadata.decode_metadata.block_tables,
  284. }
  285. def prepare_graph_input_buffers(self, input_buffers,
  286. attn_metadata) -> None:
  287. input_buffers["seq_lens_tensor"].copy_(
  288. attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
  289. input_buffers["block_tables"].copy_(
  290. attn_metadata.decode_metadata.block_tables, non_blocking=True)
  291. def begin_forward(self, model_input) -> None:
  292. return