# utils.py
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union

import torch

from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad

# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")

PAD_SLOT_ID = -1

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder


def is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False
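# Illustrative usage (hypothetical values): block tables count as "empty"
# both when they are missing entirely and when every sequence maps to None,
# as happens during memory-profiling and embedding runs:
#
#   is_block_tables_empty(None)          # True
#   is_block_tables_empty({0: None})     # True
#   is_block_tables_empty({0: [7, 8]})   # False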


def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int,
                                   use_v2_block_manager: bool):
    """
    Compute the start index of slot mapping.
    """
    start_idx = 0
    if is_prompt and sliding_window is not None:
        assert use_v2_block_manager or context_len == 0, (
            "Prefix caching is currently not supported with "
            "sliding window attention in V1 block manager")
        # During prefill, start_idx is used to skip writing slots for tokens
        # that fall outside the sliding window, to save KV-cache memory.
        start_idx = max(0, query_len - sliding_window)
    return start_idx
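# Illustrative example (hypothetical values, mirroring the sliding-window
# comment in compute_slot_mapping below): for a prefill with query_len=10,
# context_len=0 and sliding_window=8,
#
#   compute_slot_mapping_start_idx(is_prompt=True, query_len=10,
#                                  context_len=0, sliding_window=8,
#                                  use_v2_block_manager=True)
#
# returns max(0, 10 - 8) = 2, so the first two prompt tokens are never
# written to the KV cache.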


def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]):
    """
    Compute slot mapping.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    block_table = block_tables[seq_id]

    def add_slot(i):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot = block_number * block_size + block_offset
        slot_mapping.append(slot)

    if start_idx == 0 and (seq_len - context_len) == 1:
        # Optimization for the common case of decoding a single next token.
        add_slot(seq_len - 1)
    else:
        slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
        for i in range(max(start_idx, context_len), seq_len):
            add_slot(i)
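# Worked example (hypothetical block table) matching the comment above:
# with block_size=4, a prompt of length 10, context_len=0, start_idx=2 and
# block table [0, 1, 0] for seq_id 0,
#
#   mapping: List[int] = []
#   compute_slot_mapping(is_profile_run=False, slot_mapping=mapping,
#                        seq_id=0, seq_len=10, context_len=0, start_idx=2,
#                        block_size=4, block_tables={0: [0, 1, 0]})
#
# leaves mapping == [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]: PAD_SLOT_ID for the
# two masked tokens, then block_number * block_size + block_offset for the
# rest.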


TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')


class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO: Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The (possibly padded) sequence lengths of the input
                sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for CUDA graph capture;
                -1 if CUDA graphs are not used.
            batch_size: The (possibly padded) batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with empty block tables so the list length matches the
            # padded batch size.
            self.block_tables.extend([[]] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
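        # Illustrative example (hypothetical lengths): for query_lens
        # [3, 1, 1] the prefix sums above give query_start_loc
        # [0, 3, 4, 5], i.e. query i occupies rows
        # query_start_loc[i]:query_start_loc[i + 1] of the flattened token
        # batch; seq_start_loc is built the same way from seq_lens.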

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
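

# Illustrative sketch (not part of this module, names hypothetical): a
# concrete attention backend would typically subclass CommonMetadataBuilder
# and pin its metadata class, e.g.
#
#   class MyBackendMetadataBuilder(CommonMetadataBuilder[MyBackendMetadata]):
#       _metadata_cls = MyBackendMetadata
#
# where MyBackendMetadata is that backend's AttentionMetadata subclass. The
# GPU model runner's input builder then constructs the builder and calls
# build() once per batch to obtain the on-device attention metadata.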