enc_dec_model_runner.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. import dataclasses
  2. from typing import Any, Dict, List, Optional, Tuple, Type, cast
  3. import torch
  4. import torch.distributed
  5. from loguru import logger
  6. from aphrodite.attention.backends.abstract import (AttentionBackend,
  7. AttentionMetadata)
  8. from aphrodite.attention.selector import (_Backend,
  9. get_env_variable_attn_backend,
  10. get_global_forced_attn_backend,
  11. global_force_attn_backend)
  12. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  13. LoRAConfig, ModelConfig, ParallelConfig,
  14. PromptAdapterConfig, SchedulerConfig)
  15. from aphrodite.common.sampling_params import SamplingParams
  16. from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
  17. SamplerOutput, SequenceGroupMetadata)
  18. from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
  19. make_tensor_with_pad)
  20. from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
  21. from aphrodite.modeling import SamplingMetadata
  22. from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
  23. from aphrodite.task_handler.model_runner import (
  24. _PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder,
  25. ModelInputForGPUWithSamplingMetadata)
  26. from aphrodite.task_handler.model_runner_base import (
  27. _add_attn_metadata_broadcastable_dict,
  28. _add_sampling_metadata_broadcastable_dict)
  29. from aphrodite.task_handler.utils import assert_enc_dec_mr_supported_scenario
  30. @dataclasses.dataclass(frozen=True)
  31. class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
  32. """
  33. Used by the EncoderDecoderModelRunner.
  34. """
  35. encoder_input_tokens: Optional[torch.Tensor] = None
  36. encoder_input_positions: Optional[torch.Tensor] = None
  37. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  38. tensor_dict = {
  39. "input_tokens": self.input_tokens,
  40. "input_positions": self.input_positions,
  41. "encoder_input_tokens": self.encoder_input_tokens,
  42. "encoder_input_positions": self.encoder_input_positions,
  43. "virtual_engine": self.virtual_engine,
  44. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  45. "finished_requests_ids": self.finished_requests_ids,
  46. }
  47. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  48. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  49. self.sampling_metadata)
  50. return tensor_dict
  51. @classmethod
  52. def from_broadcasted_tensor_dict(
  53. cls,
  54. tensor_dict: Dict[str, Any],
  55. attn_backend: Optional["AttentionBackend"] = None,
  56. ) -> "EncoderDecoderModelInput":
  57. return cast(
  58. EncoderDecoderModelInput,
  59. super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
  60. class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
  61. _model_input_cls: Type[EncoderDecoderModelInput] = (
  62. EncoderDecoderModelInput)
  63. _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
  64. def __init__(
  65. self,
  66. model_config: ModelConfig,
  67. parallel_config: ParallelConfig,
  68. scheduler_config: SchedulerConfig,
  69. device_config: DeviceConfig,
  70. cache_config: CacheConfig,
  71. load_config: LoadConfig,
  72. lora_config: Optional[LoRAConfig],
  73. kv_cache_dtype: Optional[str] = "auto",
  74. is_driver_worker: bool = False,
  75. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  76. input_registry: InputRegistry = INPUT_REGISTRY,
  77. mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
  78. **kwargs,
  79. ):
  80. '''
  81. EncoderDecoderModelRunner constructor.
  82. `lora_config` and `prompt_adapter_config` are
  83. unused (since these features are not yet supported for encoder/decoder
  84. models) but these arguments are present here for compatibility with
  85. the base-class constructor.
  86. '''
  87. self._maybe_force_supported_attention_backend()
  88. super().__init__(
  89. model_config,
  90. parallel_config,
  91. scheduler_config,
  92. device_config,
  93. cache_config,
  94. load_config,
  95. lora_config=None,
  96. kv_cache_dtype=kv_cache_dtype,
  97. is_driver_worker=is_driver_worker,
  98. **kwargs,
  99. )
  100. # Crash for unsupported encoder/scenarios
  101. assert_enc_dec_mr_supported_scenario(self)
  102. def _maybe_force_supported_attention_backend(self):
  103. '''
  104. Force Aphrodite to use the XFormers attention backend,
  105. which is currently the only supported option.
  106. '''
  107. def raise_backend_err():
  108. # The user has specified an attention backend override
  109. # which is invalid for encoder/decoder models
  110. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
  111. maybe_env_var_forced_backend = get_env_variable_attn_backend()
  112. maybe_global_forced_backend = get_global_forced_attn_backend()
  113. is_forced_by_global = maybe_global_forced_backend is not None
  114. is_forced_by_env_var = maybe_env_var_forced_backend is not None
  115. if not (is_forced_by_global or is_forced_by_env_var):
  116. # The user has not already specified an attention backend
  117. # override
  118. logger.info("EncoderDecoderModelRunner requires "
  119. "XFormers backend; overriding backend "
  120. "auto-selection and forcing XFormers.")
  121. global_force_attn_backend(_Backend.XFORMERS)
  122. elif is_forced_by_global:
  123. # Backend override enforced by global variable takes
  124. # precedence over Aphrodite backend environment variable.
  125. if maybe_global_forced_backend != _Backend.XFORMERS:
  126. raise_backend_err()
  127. elif is_forced_by_env_var:
  128. # Backend override enforced by Aphrodite backend
  129. # environment variable
  130. if maybe_env_var_forced_backend != _Backend.XFORMERS:
  131. raise_backend_err()
  132. def _list_to_int32_tensor(
  133. self,
  134. _list: List[int],
  135. ) -> torch.Tensor:
  136. return torch.tensor(_list, dtype=torch.int32, device=self.device)
  137. def _list_to_long_tensor(
  138. self,
  139. _list: List[int],
  140. ) -> torch.Tensor:
  141. return torch.tensor(_list, dtype=torch.long, device=self.device)
  142. def _empty_int32_tensor(self) -> torch.Tensor:
  143. return self._list_to_int32_tensor([])
  144. def _empty_long_tensor(self) -> torch.Tensor:
  145. return self._list_to_long_tensor([])
  146. @torch.inference_mode()
  147. def execute_model(
  148. self,
  149. model_input: EncoderDecoderModelInput,
  150. kv_caches: List[torch.Tensor],
  151. intermediate_tensors: Optional[IntermediateTensors] = None,
  152. num_steps: int = 1,
  153. ) -> Optional[List[PoolerOutput]]:
  154. if num_steps > 1:
  155. raise ValueError("num_steps > 1 is not supported in "
  156. "EncoderDecoderModelRunner")
  157. model_executable = self.model
  158. seqlen_agnostic_kwargs = {
  159. "finished_requests_ids": model_input.finished_requests_ids,
  160. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  161. } if self.has_seqlen_agnostic else {}
  162. hidden_or_intermediate_states = model_executable(
  163. input_ids=model_input.input_tokens,
  164. positions=model_input.input_positions,
  165. encoder_input_ids=model_input.encoder_input_tokens,
  166. encoder_positions=model_input.encoder_input_positions,
  167. kv_caches=kv_caches,
  168. attn_metadata=model_input.attn_metadata,
  169. intermediate_tensors=intermediate_tensors,
  170. **seqlen_agnostic_kwargs)
  171. logits = self.model.compute_logits(hidden_or_intermediate_states,
  172. model_input.sampling_metadata)
  173. if not self.is_driver_worker:
  174. return []
  175. # Sample the next token.
  176. output: SamplerOutput = self.model.sample(
  177. logits=logits,
  178. sampling_metadata=model_input.sampling_metadata,
  179. )
  180. return [output]
  181. def make_model_input_from_broadcasted_tensor_dict(
  182. self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
  183. return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
  184. tensor_dict,
  185. attn_backend=self.attn_backend,
  186. )
  187. def prepare_model_input(
  188. self,
  189. seq_group_metadata_list: List[SequenceGroupMetadata],
  190. virtual_engine: int = 0,
  191. finished_requests_ids: Optional[List[str]] = None
  192. ) -> EncoderDecoderModelInput:
  193. """Prepare the model input based on a given sequence group, including
  194. metadata for the sampling step.
  195. Since chunked prefill is not supported for encoder/decoder models,
  196. `input_tokens` is assumed to be either entirely prefill tokens or
  197. entirely decode tokens.
  198. """
  199. model_input = self._prepare_model_input_tensors(
  200. seq_group_metadata_list, finished_requests_ids)
  201. (
  202. attn_metadata,
  203. encoder_input_tokens_tensor,
  204. encoder_input_positions_tensor,
  205. ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
  206. model_input))
  207. # Inject attn_metadata encoder/cross-attention fields &
  208. # encoder input tokens/positions into model_input.
  209. # Frozen dataclass fields cannot be modified, so use
  210. # dataclasses.replace to construct a new model input
  211. # instance.
  212. model_input = dataclasses.replace(
  213. model_input,
  214. attn_metadata=attn_metadata,
  215. encoder_input_tokens=encoder_input_tokens_tensor,
  216. encoder_input_positions=encoder_input_positions_tensor,
  217. )
  218. sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
  219. model_input.seq_lens,
  220. model_input.query_lens,
  221. self.device,
  222. self.pin_memory)
  223. is_prompt = (seq_group_metadata_list[0].is_prompt
  224. if seq_group_metadata_list else None)
  225. return dataclasses.replace(model_input,
  226. sampling_metadata=sampling_metadata,
  227. is_prompt=is_prompt,
  228. virtual_engine=virtual_engine)
  229. @torch.inference_mode()
  230. def profile_run(self) -> None:
  231. # Enable top-k sampling to reflect the accurate memory usage.
  232. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  233. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  234. max_num_seqs = self.scheduler_config.max_num_seqs
  235. # Profile memory usage with max_num_sequences sequences and the total
  236. # number of tokens equal to max_num_batched_tokens.
  237. seqs: List[SequenceGroupMetadata] = []
  238. max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
  239. self.model_config)
  240. if max_mm_tokens > 0:
  241. raise NotImplementedError(
  242. "Multi-modal encoder-decoder models are not supported yet")
  243. batch_size = 0
  244. for group_id in range(max_num_seqs):
  245. seq_len = (max_num_batched_tokens // max_num_seqs +
  246. (group_id < max_num_batched_tokens % max_num_seqs))
  247. batch_size += seq_len
  248. seq_data, _ = self.input_registry \
  249. .dummy_data_for_profiling(self.model_config,
  250. seq_len,
  251. self.mm_registry)
  252. # Having more tokens is over-conservative but otherwise fine
  253. assert len(seq_data.prompt_token_ids) >= seq_len, (
  254. f"Expected at least {seq_len} dummy tokens for profiling, "
  255. f"but got: {len(seq_data.prompt_token_ids)}")
  256. seq = SequenceGroupMetadata(
  257. request_id=str(group_id),
  258. is_prompt=True,
  259. seq_data={group_id: seq_data},
  260. sampling_params=sampling_params,
  261. block_tables=None,
  262. encoder_seq_data=seq_data,
  263. cross_block_table=None,
  264. )
  265. seqs.append(seq)
  266. # Run the model with the dummy inputs.
  267. num_layers = self.model_config.get_num_layers(self.parallel_config)
  268. kv_caches = [None] * num_layers
  269. finished_requests_ids = [seq.request_id for seq in seqs]
  270. model_input = self.prepare_model_input(
  271. seqs, finished_requests_ids=finished_requests_ids)
  272. intermediate_tensors = None
  273. self.execute_model(model_input, kv_caches, intermediate_tensors)
  274. torch.cuda.synchronize()
  275. return
  276. def _prepare_encoder_model_input_tensors(
  277. self,
  278. seq_group_metadata_list: List[SequenceGroupMetadata],
  279. model_input: EncoderDecoderModelInput,
  280. ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
  281. Optional[torch.Tensor]]:
  282. """Helper method to prepare the encoder- and cross-attn-related
  283. model inputs based on a given sequence group. These additional inputs
  284. are used to augment an already-computed `EncoderDecoderModelInput`
  285. data structure which already has decoder-related model inputs
  286. populated.
  287. Sets the following attn_metadata fields:
  288. * `num_encoder_tokens`
  289. * `encoder_seq_lens`
  290. * `encoder_seq_lens_tensor`
  291. * `max_encoder_seq_len`
  292. * `cross_slot_mapping`
  293. * `cross_block_tables`
  294. Constructs a new model inputs data structure, based on
  295. (1) the existing fields in the `model_inputs` argument,
  296. and (2) the following additional fields which are
  297. computed (or in the case of `attn_metadata`, updated)
  298. by this function:
  299. * attn_metadata
  300. * encoder_input_tokens
  301. * encoder_input_positions
  302. Arguments:
  303. * seq_group_metadata_list: list of sequence groups for which to
  304. compute inputs
  305. * model_inputs: model inputs data structure with decoder-oriented
  306. fields already computed.
  307. Return:
  308. * Updated model inputs data structure
  309. """
  310. if len(seq_group_metadata_list) == 0:
  311. return (model_input.attn_metadata, None, None)
  312. # Since we are not supporting chunked prefill either the entire
  313. # batch is prefill or it is decode
  314. is_prompt = seq_group_metadata_list[0].is_prompt
  315. # Build encoder inputs
  316. encoder_seq_lens: List[int] = []
  317. if is_prompt:
  318. # Prefill phase.
  319. cross_block_tables = self._empty_int32_tensor().view(
  320. len(seq_group_metadata_list), -1)
  321. # Extract input tokens/positions, cross-attention slot-mapping,
  322. # & seq len from each sequence group metadata
  323. (
  324. encoder_input_tokens,
  325. encoder_input_positions,
  326. cross_slot_mapping,
  327. ) = (
  328. [],
  329. [],
  330. [],
  331. )
  332. for seq_group_metadata in seq_group_metadata_list:
  333. # Build seq lens
  334. seq_len = seq_group_metadata.encoder_seq_data.get_len()
  335. token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
  336. encoder_seq_lens.append(seq_len)
  337. # Build slot mapping
  338. is_profile_run = (seq_group_metadata.block_tables is None)
  339. if is_profile_run:
  340. # During memory profiling, the block tables are not
  341. # initialized yet. In this case, we just use a dummy
  342. # slot mapping.
  343. # In embeddings, the block tables are {seq_id: None}.
  344. cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
  345. else:
  346. for i in range(0, seq_len):
  347. block_number = seq_group_metadata.cross_block_table[
  348. i // self.block_size]
  349. block_offset = i % self.block_size
  350. slot = block_number * self.block_size + block_offset
  351. cross_slot_mapping.append(slot)
  352. # Build encoder input tokens
  353. encoder_input_tokens.extend(token_ids)
  354. encoder_input_positions.extend(list(range(0, seq_len)))
  355. # Convert tokens/positions & cross-attention
  356. # slot-mapping to encoder input tensors
  357. encoder_input_tokens_tensor = self._list_to_long_tensor(
  358. encoder_input_tokens)
  359. encoder_input_positions_tensor = self._list_to_long_tensor(
  360. encoder_input_positions)
  361. cross_slot_mapping_tensor = self._list_to_long_tensor(
  362. cross_slot_mapping)
  363. else:
  364. # Decode phase.
  365. encoder_input_tokens_tensor = self._empty_long_tensor()
  366. encoder_input_positions_tensor = self._empty_long_tensor()
  367. cross_slot_mapping_tensor = self._empty_long_tensor()
  368. # Extract cross-attention block tables &
  369. # seq len from each sequence group metadata.
  370. # Cross-attention block tables are empty
  371. # during Aphrodite memory profiling.
  372. cross_block_tables = []
  373. for seq_group_metadata in seq_group_metadata_list:
  374. encoder_seq_lens.append(
  375. seq_group_metadata.encoder_seq_data.get_len())
  376. cross_block_table = seq_group_metadata.cross_block_table
  377. cross_block_tables.append([] if (
  378. cross_block_table is None) else cross_block_table)
  379. # Convert cross-attention block tables to encoder input tensor
  380. cross_block_tables = make_tensor_with_pad(
  381. cross_block_tables,
  382. max_len=max(
  383. len(block_table) for block_table in cross_block_tables),
  384. pad=0,
  385. dtype=torch.int32,
  386. device=self.device,
  387. )
  388. # Compute encoder sequence lengths & encoder
  389. # sequence starting offset tensors
  390. max_encoder_seq_len = max(encoder_seq_lens, default=0)
  391. encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
  392. encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
  393. 1,
  394. dtype=torch.int32,
  395. device=self.device)
  396. torch.cumsum(encoder_seq_lens_tensor,
  397. dim=0,
  398. dtype=encoder_seq_start_loc.dtype,
  399. out=encoder_seq_start_loc[1:])
  400. # Update attention metadata with encoder-oriented attributes
  401. attn_metadata = model_input.attn_metadata
  402. assert attn_metadata is not None
  403. (
  404. attn_metadata.num_encoder_tokens,
  405. attn_metadata.encoder_seq_lens,
  406. attn_metadata.encoder_seq_lens_tensor,
  407. attn_metadata.max_encoder_seq_len,
  408. attn_metadata.cross_slot_mapping,
  409. attn_metadata.cross_block_tables,
  410. ) = (
  411. sum(encoder_seq_lens),
  412. encoder_seq_lens,
  413. encoder_seq_lens_tensor,
  414. max_encoder_seq_len,
  415. cross_slot_mapping_tensor,
  416. cross_block_tables,
  417. )
  418. return (attn_metadata, encoder_input_tokens_tensor,
  419. encoder_input_positions_tensor)