1
0

enc_dec_model_runner.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. import dataclasses
  2. import itertools
  3. from typing import Any, Dict, List, Optional, Tuple, Type, cast
  4. import torch
  5. import torch.distributed
  6. from loguru import logger
  7. from aphrodite.attention.backends.abstract import (AttentionBackend,
  8. AttentionMetadata)
  9. from aphrodite.attention.backends.utils import PAD_SLOT_ID
  10. from aphrodite.attention.selector import (_Backend,
  11. get_env_variable_attn_backend,
  12. get_global_forced_attn_backend,
  13. global_force_attn_backend)
  14. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  15. LoRAConfig, ModelConfig, ParallelConfig,
  16. PromptAdapterConfig, SchedulerConfig)
  17. from aphrodite.common.sampling_params import SamplingParams
  18. from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
  19. SequenceGroupMetadata)
  20. from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
  21. make_tensor_with_pad)
  22. from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
  23. from aphrodite.modeling.layers.sampler import SamplerOutput
  24. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  25. from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
  26. from aphrodite.worker.model_runner import (
  27. GPUModelRunnerBase, ModelInputForGPUBuilder,
  28. ModelInputForGPUWithSamplingMetadata, _get_graph_batch_size)
  29. from aphrodite.worker.model_runner_base import (
  30. _add_attn_metadata_broadcastable_dict,
  31. _add_sampling_metadata_broadcastable_dict)
  32. from aphrodite.worker.utils import assert_enc_dec_mr_supported_scenario
  33. @dataclasses.dataclass(frozen=True)
  34. class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
  35. """
  36. Used by the EncoderDecoderModelRunner.
  37. """
  38. encoder_input_tokens: Optional[torch.Tensor] = None
  39. encoder_input_positions: Optional[torch.Tensor] = None
  40. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  41. tensor_dict = {
  42. "input_tokens": self.input_tokens,
  43. "input_positions": self.input_positions,
  44. "encoder_input_tokens": self.encoder_input_tokens,
  45. "encoder_input_positions": self.encoder_input_positions,
  46. "virtual_engine": self.virtual_engine,
  47. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  48. "finished_requests_ids": self.finished_requests_ids,
  49. }
  50. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  51. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  52. self.sampling_metadata)
  53. return tensor_dict
  54. @classmethod
  55. def from_broadcasted_tensor_dict(
  56. cls,
  57. tensor_dict: Dict[str, Any],
  58. attn_backend: Optional["AttentionBackend"] = None,
  59. ) -> "EncoderDecoderModelInput":
  60. return cast(
  61. EncoderDecoderModelInput,
  62. super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
  63. class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
  64. _model_input_cls: Type[EncoderDecoderModelInput] = (
  65. EncoderDecoderModelInput)
  66. _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
  67. def __init__(
  68. self,
  69. model_config: ModelConfig,
  70. parallel_config: ParallelConfig,
  71. scheduler_config: SchedulerConfig,
  72. device_config: DeviceConfig,
  73. cache_config: CacheConfig,
  74. load_config: LoadConfig,
  75. lora_config: Optional[LoRAConfig],
  76. kv_cache_dtype: Optional[str] = "auto",
  77. is_driver_worker: bool = False,
  78. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  79. input_registry: InputRegistry = INPUT_REGISTRY,
  80. mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
  81. **kwargs,
  82. ):
  83. '''
  84. EncoderDecoderModelRunner constructor.
  85. `lora_config` and `prompt_adapter_config` are
  86. unused (since these features are not yet supported for encoder/decoder
  87. models) but these arguments are present here for compatibility with
  88. the base-class constructor.
  89. '''
  90. self._maybe_force_supported_attention_backend()
  91. super().__init__(
  92. model_config,
  93. parallel_config,
  94. scheduler_config,
  95. device_config,
  96. cache_config,
  97. load_config,
  98. lora_config=None,
  99. kv_cache_dtype=kv_cache_dtype,
  100. is_driver_worker=is_driver_worker,
  101. **kwargs,
  102. )
  103. # Crash for unsupported encoder/scenarios
  104. assert_enc_dec_mr_supported_scenario(self)
  105. def _maybe_force_supported_attention_backend(self):
  106. '''
  107. Force Aphrodite to use the XFormers attention backend,
  108. which is currently the only supported option.
  109. '''
  110. def raise_backend_err():
  111. # The user has specified an attention backend override
  112. # which is invalid for encoder/decoder models
  113. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
  114. maybe_env_var_forced_backend = get_env_variable_attn_backend()
  115. maybe_global_forced_backend = get_global_forced_attn_backend()
  116. is_forced_by_global = maybe_global_forced_backend is not None
  117. is_forced_by_env_var = maybe_env_var_forced_backend is not None
  118. if not (is_forced_by_global or is_forced_by_env_var):
  119. # The user has not already specified an attention backend
  120. # override
  121. logger.info("EncoderDecoderModelRunner requires "
  122. "XFormers backend; overriding backend "
  123. "auto-selection and forcing XFormers.")
  124. global_force_attn_backend(_Backend.XFORMERS)
  125. elif is_forced_by_global:
  126. # Backend override enforced by global variable takes
  127. # precedence over Aphrodite backend environment variable.
  128. if maybe_global_forced_backend != _Backend.XFORMERS:
  129. raise_backend_err()
  130. elif is_forced_by_env_var:
  131. # Backend override enforced by Aphrodite backend
  132. # environment variable
  133. if maybe_env_var_forced_backend != _Backend.XFORMERS:
  134. raise_backend_err()
  135. def _list_to_int32_tensor(
  136. self,
  137. _list: List[int],
  138. ) -> torch.Tensor:
  139. return torch.tensor(_list, dtype=torch.int32, device=self.device)
  140. def _list_to_long_tensor(
  141. self,
  142. _list: List[int],
  143. ) -> torch.Tensor:
  144. return torch.tensor(_list, dtype=torch.long, device=self.device)
  145. def _empty_int32_tensor(self) -> torch.Tensor:
  146. return self._list_to_int32_tensor([])
  147. def _empty_long_tensor(self) -> torch.Tensor:
  148. return self._list_to_long_tensor([])
  149. @torch.inference_mode()
  150. def execute_model(
  151. self,
  152. model_input: EncoderDecoderModelInput,
  153. kv_caches: List[torch.Tensor],
  154. intermediate_tensors: Optional[IntermediateTensors] = None,
  155. num_steps: int = 1,
  156. ) -> Optional[List[PoolerOutput]]:
  157. if num_steps > 1:
  158. raise ValueError("num_steps > 1 is not supported in "
  159. "EncoderDecoderModelRunner")
  160. if (model_input.attn_metadata is not None
  161. and model_input.attn_metadata.prefill_metadata is None
  162. and model_input.attn_metadata.decode_metadata.use_cuda_graph):
  163. assert model_input.input_tokens is not None
  164. graph_batch_size = model_input.input_tokens.shape[0]
  165. model_executable = self.graph_runners[
  166. model_input.virtual_engine][graph_batch_size]
  167. else:
  168. model_executable = self.model
  169. seqlen_agnostic_kwargs = {
  170. "finished_requests_ids": model_input.finished_requests_ids,
  171. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  172. } if self.has_seqlen_agnostic else {}
  173. hidden_or_intermediate_states = model_executable(
  174. input_ids=model_input.input_tokens,
  175. positions=model_input.input_positions,
  176. encoder_input_ids=model_input.encoder_input_tokens,
  177. encoder_positions=model_input.encoder_input_positions,
  178. kv_caches=kv_caches,
  179. attn_metadata=model_input.attn_metadata,
  180. intermediate_tensors=intermediate_tensors,
  181. **seqlen_agnostic_kwargs)
  182. logits = self.model.compute_logits(hidden_or_intermediate_states,
  183. model_input.sampling_metadata)
  184. if not self.is_driver_worker:
  185. return []
  186. if model_input.async_callback is not None:
  187. model_input.async_callback()
  188. # Sample the next token.
  189. output: SamplerOutput = self.model.sample(
  190. logits=logits,
  191. sampling_metadata=model_input.sampling_metadata,
  192. )
  193. return [output]
  194. def make_model_input_from_broadcasted_tensor_dict(
  195. self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
  196. return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
  197. tensor_dict,
  198. attn_backend=self.attn_backend,
  199. )
  200. def prepare_model_input(
  201. self,
  202. seq_group_metadata_list: List[SequenceGroupMetadata],
  203. virtual_engine: int = 0,
  204. finished_requests_ids: Optional[List[str]] = None
  205. ) -> EncoderDecoderModelInput:
  206. """Prepare the model input based on a given sequence group, including
  207. metadata for the sampling step.
  208. Since chunked prefill is not supported for encoder/decoder models,
  209. `input_tokens` is assumed to be either entirely prefill tokens or
  210. entirely decode tokens.
  211. """
  212. model_input = self._prepare_model_input_tensors(
  213. seq_group_metadata_list, finished_requests_ids)
  214. (
  215. attn_metadata,
  216. encoder_input_tokens_tensor,
  217. encoder_input_positions_tensor,
  218. ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
  219. model_input))
  220. # Inject attn_metadata encoder/cross-attention fields &
  221. # encoder input tokens/positions into model_input.
  222. # Frozen dataclass fields cannot be modified, so use
  223. # dataclasses.replace to construct a new model input
  224. # instance.
  225. model_input = dataclasses.replace(
  226. model_input,
  227. attn_metadata=attn_metadata,
  228. encoder_input_tokens=encoder_input_tokens_tensor,
  229. encoder_input_positions=encoder_input_positions_tensor,
  230. )
  231. sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
  232. model_input.seq_lens,
  233. model_input.query_lens,
  234. self.device,
  235. self.pin_memory)
  236. is_prompt = (seq_group_metadata_list[0].is_prompt
  237. if seq_group_metadata_list else None)
  238. return dataclasses.replace(model_input,
  239. sampling_metadata=sampling_metadata,
  240. is_prompt=is_prompt,
  241. virtual_engine=virtual_engine)
  242. @torch.inference_mode()
  243. def profile_run(self) -> None:
  244. # Enable top-k sampling to reflect the accurate memory usage.
  245. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  246. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  247. max_num_seqs = self.scheduler_config.max_num_seqs
  248. # Profile memory usage with max_num_sequences sequences and the total
  249. # number of tokens equal to max_num_batched_tokens.
  250. seqs: List[SequenceGroupMetadata] = []
  251. max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
  252. self.model_config)
  253. if max_mm_tokens > 0:
  254. raise NotImplementedError(
  255. "Multi-modal encoder-decoder models are not supported yet")
  256. batch_size = 0
  257. for group_id in range(max_num_seqs):
  258. seq_len = (max_num_batched_tokens // max_num_seqs +
  259. (group_id < max_num_batched_tokens % max_num_seqs))
  260. batch_size += seq_len
  261. seq_data, _ = self.input_registry \
  262. .dummy_data_for_profiling(self.model_config,
  263. seq_len,
  264. self.mm_registry)
  265. # Having more tokens is over-conservative but otherwise fine
  266. assert len(seq_data.prompt_token_ids) >= seq_len, (
  267. f"Expected at least {seq_len} dummy tokens for profiling, "
  268. f"but got: {len(seq_data.prompt_token_ids)}")
  269. seq = SequenceGroupMetadata(
  270. request_id=str(group_id),
  271. is_prompt=True,
  272. seq_data={group_id: seq_data},
  273. sampling_params=sampling_params,
  274. block_tables=None,
  275. encoder_seq_data=seq_data,
  276. cross_block_table=None,
  277. )
  278. seqs.append(seq)
  279. # Run the model with the dummy inputs.
  280. num_layers = self.model_config.get_num_layers(self.parallel_config)
  281. kv_caches = [None] * num_layers
  282. finished_requests_ids = [seq.request_id for seq in seqs]
  283. model_input = self.prepare_model_input(
  284. seqs, finished_requests_ids=finished_requests_ids)
  285. intermediate_tensors = None
  286. self.execute_model(model_input, kv_caches, intermediate_tensors)
  287. torch.cuda.synchronize()
  288. return
  289. def _prepare_encoder_model_input_tensors(
  290. self,
  291. seq_group_metadata_list: List[SequenceGroupMetadata],
  292. model_input: EncoderDecoderModelInput,
  293. ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
  294. Optional[torch.Tensor]]:
  295. """Helper method to prepare the encoder- and cross-attn-related
  296. model inputs based on a given sequence group. These additional inputs
  297. are used to augment an already-computed `EncoderDecoderModelInput`
  298. data structure which already has decoder-related model inputs
  299. populated.
  300. Sets the following attn_metadata fields:
  301. * `num_encoder_tokens`
  302. * `encoder_seq_lens`
  303. * `encoder_seq_lens_tensor`
  304. * `max_encoder_seq_len`
  305. * `cross_slot_mapping`
  306. * `cross_block_tables`
  307. Constructs a new model inputs data structure, based on
  308. (1) the existing fields in the `model_inputs` argument,
  309. and (2) the following additional fields which are
  310. computed (or in the case of `attn_metadata`, updated)
  311. by this function:
  312. * attn_metadata
  313. * encoder_input_tokens
  314. * encoder_input_positions
  315. Arguments:
  316. * seq_group_metadata_list: list of sequence groups for which to
  317. compute inputs
  318. * model_inputs: model inputs data structure with decoder-oriented
  319. fields already computed.
  320. Return:
  321. * Updated model inputs data structure
  322. """
  323. if len(seq_group_metadata_list) == 0:
  324. return (model_input.attn_metadata, None, None)
  325. # Since we are not supporting chunked prefill either the entire
  326. # batch is prefill or it is decode
  327. is_prompt = seq_group_metadata_list[0].is_prompt
  328. # Build encoder inputs
  329. encoder_seq_lens: List[int] = []
  330. if is_prompt:
  331. # Prefill phase.
  332. cross_block_tables = self._empty_int32_tensor().view(
  333. len(seq_group_metadata_list), -1)
  334. # Extract input tokens/positions, cross-attention slot-mapping,
  335. # & seq len from each sequence group metadata
  336. (
  337. encoder_input_tokens,
  338. encoder_input_positions,
  339. cross_slot_mapping,
  340. ) = (
  341. [],
  342. [],
  343. [],
  344. )
  345. for seq_group_metadata in seq_group_metadata_list:
  346. # Build seq lens
  347. seq_len = seq_group_metadata.encoder_seq_data.get_len()
  348. token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
  349. encoder_seq_lens.append(seq_len)
  350. # Build slot mapping
  351. is_profile_run = (seq_group_metadata.block_tables is None)
  352. if is_profile_run:
  353. # During memory profiling, the block tables are not
  354. # initialized yet. In this case, we just use a dummy
  355. # slot mapping.
  356. # In embeddings, the block tables are {seq_id: None}.
  357. cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
  358. else:
  359. for i in range(0, seq_len):
  360. block_number = seq_group_metadata.cross_block_table[
  361. i // self.block_size]
  362. block_offset = i % self.block_size
  363. slot = block_number * self.block_size + block_offset
  364. cross_slot_mapping.append(slot)
  365. # Build encoder input tokens
  366. encoder_input_tokens.extend(token_ids)
  367. encoder_input_positions.extend(list(range(0, seq_len)))
  368. # Convert tokens/positions & cross-attention
  369. # slot-mapping to encoder input tensors
  370. encoder_input_tokens_tensor = self._list_to_long_tensor(
  371. encoder_input_tokens)
  372. encoder_input_positions_tensor = self._list_to_long_tensor(
  373. encoder_input_positions)
  374. cross_slot_mapping_tensor = self._list_to_long_tensor(
  375. cross_slot_mapping)
  376. else:
  377. # Decode phase.
  378. encoder_input_tokens_tensor = self._empty_long_tensor()
  379. encoder_input_positions_tensor = self._empty_long_tensor()
  380. cross_slot_mapping_tensor = self._empty_long_tensor()
  381. # Extract cross-attention block tables &
  382. # seq len from each sequence group metadata.
  383. # Cross-attention block tables are empty
  384. # during Aphrodite memory profiling.
  385. cross_block_tables = []
  386. for seq_group_metadata in seq_group_metadata_list:
  387. encoder_seq_lens.append(
  388. seq_group_metadata.encoder_seq_data.get_len())
  389. cross_block_table = seq_group_metadata.cross_block_table
  390. cross_block_tables.append([] if (
  391. cross_block_table is None) else cross_block_table)
  392. if (model_input.attn_metadata is not None
  393. and model_input.attn_metadata.use_cuda_graph):
  394. # We will be using CUDA graph replay for this decode.
  395. max_len_of_block_table = self.get_max_block_per_batch()
  396. batch_size = len(encoder_seq_lens)
  397. graph_batch_size = _get_graph_batch_size(batch_size)
  398. assert graph_batch_size >= batch_size
  399. cuda_graph_pad_size = graph_batch_size - batch_size
  400. # extend the cross_block_tables and encoder_seq_lens to match
  401. # the graph_batch_size.
  402. cross_block_tables.extend([[]
  403. for _ in range(cuda_graph_pad_size)
  404. ])
  405. encoder_seq_lens.extend(
  406. itertools.repeat(1, cuda_graph_pad_size))
  407. else:
  408. max_len_of_block_table = max(
  409. len(block_table) for block_table in cross_block_tables)
  410. cross_block_tables = make_tensor_with_pad(
  411. cross_block_tables,
  412. max_len=max_len_of_block_table,
  413. pad=0,
  414. dtype=torch.int32,
  415. device=self.device,
  416. )
  417. # Compute encoder sequence lengths & encoder
  418. # sequence starting offset tensors
  419. max_encoder_seq_len = max(encoder_seq_lens, default=0)
  420. encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
  421. encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
  422. 1,
  423. dtype=torch.int32,
  424. device=self.device)
  425. torch.cumsum(encoder_seq_lens_tensor,
  426. dim=0,
  427. dtype=encoder_seq_start_loc.dtype,
  428. out=encoder_seq_start_loc[1:])
  429. # Update attention metadata with encoder-oriented attributes
  430. attn_metadata = model_input.attn_metadata
  431. assert attn_metadata is not None
  432. (
  433. attn_metadata.num_encoder_tokens,
  434. attn_metadata.encoder_seq_lens,
  435. attn_metadata.encoder_seq_lens_tensor,
  436. attn_metadata.max_encoder_seq_len,
  437. attn_metadata.cross_slot_mapping,
  438. attn_metadata.cross_block_tables,
  439. ) = (
  440. sum(encoder_seq_lens),
  441. encoder_seq_lens,
  442. encoder_seq_lens_tensor,
  443. max_encoder_seq_len,
  444. cross_slot_mapping_tensor,
  445. cross_block_tables,
  446. )
  447. return (attn_metadata, encoder_input_tokens_tensor,
  448. encoder_input_positions_tensor)