from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union

import numpy as np
import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad

# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")

PAD_SLOT_ID = -1

# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder


def is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False


def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int,
                                   use_v2_block_manager: bool):
    """
    Compute the start index of slot mapping.
    """
    start_idx = 0
    if is_prompt and sliding_window is not None:
        assert use_v2_block_manager or context_len == 0, (
            "Prefix caching is currently not supported with "
            "sliding window attention in V1 block manager")
        # When prefill, we use it to not write slots to kv cache
        # to save memory.
        start_idx = max(0, query_len - sliding_window)
    return start_idx
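# Illustrative example (values chosen here for illustration, not taken from
# any particular run): for a prompt with query_len=10 and sliding_window=8
# (context_len=0), start_idx = max(0, 10 - 8) = 2, so the first two prompt
# tokens are later masked in the slot mapping and never written to the KV
# cache.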


def _compute_slot_mapping_python(slot_mapping: List[int],
                                 block_table: List[int], range_start: int,
                                 range_end: int, block_size: int):
    for i in range(range_start, range_end):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot = block_number * block_size + block_offset
        slot_mapping.append(slot)


def _compute_slot_mapping_numpy(slot_mapping: List[int],
                                block_table: List[int], range_start: int,
                                range_end: int, block_size: int):
    block_table_array = np.array(block_table)
    idx = np.arange(range_start, range_end)
    block_offset = idx % block_size
    idx //= block_size
    seq_slot_mapping_array = block_table_array[idx]
    seq_slot_mapping_array *= block_size
    seq_slot_mapping_array += block_offset
    slot_mapping.extend(seq_slot_mapping_array)
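# Illustrative example (hypothetical values): with block_table=[7, 1],
# block_size=4, range_start=2 and range_end=6, both implementations append
# slots [7*4+2, 7*4+3, 1*4+0, 1*4+1] == [30, 31, 4, 5]; the numpy path just
# vectorizes the same block-number/offset arithmetic over the whole range.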


def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """
    Compute slot mapping.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    padding_mask_len = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)

    range_start = max(start_idx, context_len)
    range_end = seq_len
    numel = range_end - range_start
    block_table = block_tables[seq_id]

    # numpy implementation will be faster than python if we have
    # many elements, otherwise it will be slower.
    if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        _compute_slot_mapping_python(slot_mapping, block_table, range_start,
                                     range_end, block_size)
    else:
        _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
                                    range_end, block_size)
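# Worked example for the comment above (the block table contents are an
# assumption chosen for illustration): with seq_len=10, context_len=0,
# start_idx=2, block_size=4 and block_tables={0: [0, 1, 0]} (the last
# physical block is reused under sliding window), the function appends two
# PAD_SLOT_ID entries followed by the slots for tokens 2..9, i.e.
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].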


TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')


class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO: Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad the block tables so there is one (empty) entry per
            # graph-padded slot.
            self.block_tables.extend(
                [[] for _ in range(cuda_graph_pad_size)])
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
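        # Illustrative example (hypothetical batch): for query_lens=[5, 1, 1]
        # and seq_lens=[5, 8, 12], the exclusive prefix sums above yield
        # query_start_loc=[0, 5, 6, 7] and seq_start_loc=[0, 5, 13, 25],
        # i.e. the start offset of each sequence in the flattened token
        # tensors consumed by the attention kernels.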

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
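

# Typical usage (a minimal sketch; the backend-specific names below are
# assumptions for illustration, not part of this module): an attention
# backend subclasses CommonMetadataBuilder, binds _metadata_cls to its own
# AttentionMetadata subclass, and lets the model runner drive it, e.g.
#
#     class MyBackendMetadataBuilder(CommonMetadataBuilder[MyBackendMetadata]):
#         _metadata_cls = MyBackendMetadata
#
#     builder = MyBackendMetadataBuilder(input_builder)
#     attn_metadata = builder.build(seq_lens, query_lens,
#                                   cuda_graph_pad_size=-1,
#                                   batch_size=len(seq_lens))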