enc_dec_model_runner.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. import dataclasses
  2. from typing import Any, Dict, List, Optional, Tuple, Type, cast
  3. import torch
  4. import torch.distributed
  5. from loguru import logger
  6. from aphrodite.attention.backends.abstract import (AttentionBackend,
  7. AttentionMetadata)
  8. from aphrodite.attention.backends.utils import PAD_SLOT_ID
  9. from aphrodite.attention.selector import (_Backend,
  10. get_env_variable_attn_backend,
  11. get_global_forced_attn_backend,
  12. global_force_attn_backend)
  13. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  14. LoRAConfig, ModelConfig, ParallelConfig,
  15. PromptAdapterConfig, SchedulerConfig)
  16. from aphrodite.common.sampling_params import SamplingParams
  17. from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
  18. SequenceGroupMetadata)
  19. from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
  20. make_tensor_with_pad)
  21. from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
  22. from aphrodite.modeling.layers.sampler import SamplerOutput
  23. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  24. from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
  25. from aphrodite.worker.model_runner import (
  26. GPUModelRunnerBase, ModelInputForGPUBuilder,
  27. ModelInputForGPUWithSamplingMetadata)
  28. from aphrodite.worker.model_runner_base import (
  29. _add_attn_metadata_broadcastable_dict,
  30. _add_sampling_metadata_broadcastable_dict)
  31. from aphrodite.worker.utils import assert_enc_dec_mr_supported_scenario
  32. @dataclasses.dataclass(frozen=True)
  33. class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
  34. """
  35. Used by the EncoderDecoderModelRunner.
  36. """
  37. encoder_input_tokens: Optional[torch.Tensor] = None
  38. encoder_input_positions: Optional[torch.Tensor] = None
  39. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  40. tensor_dict = {
  41. "input_tokens": self.input_tokens,
  42. "input_positions": self.input_positions,
  43. "encoder_input_tokens": self.encoder_input_tokens,
  44. "encoder_input_positions": self.encoder_input_positions,
  45. "virtual_engine": self.virtual_engine,
  46. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  47. "finished_requests_ids": self.finished_requests_ids,
  48. }
  49. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  50. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  51. self.sampling_metadata)
  52. return tensor_dict
  53. @classmethod
  54. def from_broadcasted_tensor_dict(
  55. cls,
  56. tensor_dict: Dict[str, Any],
  57. attn_backend: Optional["AttentionBackend"] = None,
  58. ) -> "EncoderDecoderModelInput":
  59. return cast(
  60. EncoderDecoderModelInput,
  61. super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
  62. class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
  63. _model_input_cls: Type[EncoderDecoderModelInput] = (
  64. EncoderDecoderModelInput)
  65. _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
  66. def __init__(
  67. self,
  68. model_config: ModelConfig,
  69. parallel_config: ParallelConfig,
  70. scheduler_config: SchedulerConfig,
  71. device_config: DeviceConfig,
  72. cache_config: CacheConfig,
  73. load_config: LoadConfig,
  74. lora_config: Optional[LoRAConfig],
  75. kv_cache_dtype: Optional[str] = "auto",
  76. is_driver_worker: bool = False,
  77. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  78. input_registry: InputRegistry = INPUT_REGISTRY,
  79. mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
  80. **kwargs,
  81. ):
  82. '''
  83. EncoderDecoderModelRunner constructor.
  84. `lora_config` and `prompt_adapter_config` are
  85. unused (since these features are not yet supported for encoder/decoder
  86. models) but these arguments are present here for compatibility with
  87. the base-class constructor.
  88. '''
  89. self._maybe_force_supported_attention_backend()
  90. super().__init__(
  91. model_config,
  92. parallel_config,
  93. scheduler_config,
  94. device_config,
  95. cache_config,
  96. load_config,
  97. lora_config=None,
  98. kv_cache_dtype=kv_cache_dtype,
  99. is_driver_worker=is_driver_worker,
  100. **kwargs,
  101. )
  102. # Crash for unsupported encoder/scenarios
  103. assert_enc_dec_mr_supported_scenario(self)
  104. def _maybe_force_supported_attention_backend(self):
  105. '''
  106. Force Aphrodite to use the XFormers attention backend,
  107. which is currently the only supported option.
  108. '''
  109. def raise_backend_err():
  110. # The user has specified an attention backend override
  111. # which is invalid for encoder/decoder models
  112. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
  113. maybe_env_var_forced_backend = get_env_variable_attn_backend()
  114. maybe_global_forced_backend = get_global_forced_attn_backend()
  115. is_forced_by_global = maybe_global_forced_backend is not None
  116. is_forced_by_env_var = maybe_env_var_forced_backend is not None
  117. if not (is_forced_by_global or is_forced_by_env_var):
  118. # The user has not already specified an attention backend
  119. # override
  120. logger.info("EncoderDecoderModelRunner requires "
  121. "XFormers backend; overriding backend "
  122. "auto-selection and forcing XFormers.")
  123. global_force_attn_backend(_Backend.XFORMERS)
  124. elif is_forced_by_global:
  125. # Backend override enforced by global variable takes
  126. # precedence over Aphrodite backend environment variable.
  127. if maybe_global_forced_backend != _Backend.XFORMERS:
  128. raise_backend_err()
  129. elif is_forced_by_env_var:
  130. # Backend override enforced by Aphrodite backend
  131. # environment variable
  132. if maybe_env_var_forced_backend != _Backend.XFORMERS:
  133. raise_backend_err()
  134. def _list_to_int32_tensor(
  135. self,
  136. _list: List[int],
  137. ) -> torch.Tensor:
  138. return torch.tensor(_list, dtype=torch.int32, device=self.device)
  139. def _list_to_long_tensor(
  140. self,
  141. _list: List[int],
  142. ) -> torch.Tensor:
  143. return torch.tensor(_list, dtype=torch.long, device=self.device)
  144. def _empty_int32_tensor(self) -> torch.Tensor:
  145. return self._list_to_int32_tensor([])
  146. def _empty_long_tensor(self) -> torch.Tensor:
  147. return self._list_to_long_tensor([])
  148. @torch.inference_mode()
  149. def execute_model(
  150. self,
  151. model_input: EncoderDecoderModelInput,
  152. kv_caches: List[torch.Tensor],
  153. intermediate_tensors: Optional[IntermediateTensors] = None,
  154. num_steps: int = 1,
  155. ) -> Optional[List[PoolerOutput]]:
  156. if num_steps > 1:
  157. raise ValueError("num_steps > 1 is not supported in "
  158. "EncoderDecoderModelRunner")
  159. model_executable = self.model
  160. seqlen_agnostic_kwargs = {
  161. "finished_requests_ids": model_input.finished_requests_ids,
  162. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  163. } if self.has_seqlen_agnostic else {}
  164. hidden_or_intermediate_states = model_executable(
  165. input_ids=model_input.input_tokens,
  166. positions=model_input.input_positions,
  167. encoder_input_ids=model_input.encoder_input_tokens,
  168. encoder_positions=model_input.encoder_input_positions,
  169. kv_caches=kv_caches,
  170. attn_metadata=model_input.attn_metadata,
  171. intermediate_tensors=intermediate_tensors,
  172. **seqlen_agnostic_kwargs)
  173. logits = self.model.compute_logits(hidden_or_intermediate_states,
  174. model_input.sampling_metadata)
  175. if not self.is_driver_worker:
  176. return []
  177. # Sample the next token.
  178. output: SamplerOutput = self.model.sample(
  179. logits=logits,
  180. sampling_metadata=model_input.sampling_metadata,
  181. )
  182. return [output]
  183. def make_model_input_from_broadcasted_tensor_dict(
  184. self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
  185. return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
  186. tensor_dict,
  187. attn_backend=self.attn_backend,
  188. )
  189. def prepare_model_input(
  190. self,
  191. seq_group_metadata_list: List[SequenceGroupMetadata],
  192. virtual_engine: int = 0,
  193. finished_requests_ids: Optional[List[str]] = None
  194. ) -> EncoderDecoderModelInput:
  195. """Prepare the model input based on a given sequence group, including
  196. metadata for the sampling step.
  197. Since chunked prefill is not supported for encoder/decoder models,
  198. `input_tokens` is assumed to be either entirely prefill tokens or
  199. entirely decode tokens.
  200. """
  201. model_input = self._prepare_model_input_tensors(
  202. seq_group_metadata_list, finished_requests_ids)
  203. (
  204. attn_metadata,
  205. encoder_input_tokens_tensor,
  206. encoder_input_positions_tensor,
  207. ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
  208. model_input))
  209. # Inject attn_metadata encoder/cross-attention fields &
  210. # encoder input tokens/positions into model_input.
  211. # Frozen dataclass fields cannot be modified, so use
  212. # dataclasses.replace to construct a new model input
  213. # instance.
  214. model_input = dataclasses.replace(
  215. model_input,
  216. attn_metadata=attn_metadata,
  217. encoder_input_tokens=encoder_input_tokens_tensor,
  218. encoder_input_positions=encoder_input_positions_tensor,
  219. )
  220. sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
  221. model_input.seq_lens,
  222. model_input.query_lens,
  223. self.device,
  224. self.pin_memory)
  225. is_prompt = (seq_group_metadata_list[0].is_prompt
  226. if seq_group_metadata_list else None)
  227. return dataclasses.replace(model_input,
  228. sampling_metadata=sampling_metadata,
  229. is_prompt=is_prompt,
  230. virtual_engine=virtual_engine)
  231. @torch.inference_mode()
  232. def profile_run(self) -> None:
  233. # Enable top-k sampling to reflect the accurate memory usage.
  234. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  235. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  236. max_num_seqs = self.scheduler_config.max_num_seqs
  237. # Profile memory usage with max_num_sequences sequences and the total
  238. # number of tokens equal to max_num_batched_tokens.
  239. seqs: List[SequenceGroupMetadata] = []
  240. max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
  241. self.model_config)
  242. if max_mm_tokens > 0:
  243. raise NotImplementedError(
  244. "Multi-modal encoder-decoder models are not supported yet")
  245. batch_size = 0
  246. for group_id in range(max_num_seqs):
  247. seq_len = (max_num_batched_tokens // max_num_seqs +
  248. (group_id < max_num_batched_tokens % max_num_seqs))
  249. batch_size += seq_len
  250. seq_data, _ = self.input_registry \
  251. .dummy_data_for_profiling(self.model_config,
  252. seq_len,
  253. self.mm_registry)
  254. # Having more tokens is over-conservative but otherwise fine
  255. assert len(seq_data.prompt_token_ids) >= seq_len, (
  256. f"Expected at least {seq_len} dummy tokens for profiling, "
  257. f"but got: {len(seq_data.prompt_token_ids)}")
  258. seq = SequenceGroupMetadata(
  259. request_id=str(group_id),
  260. is_prompt=True,
  261. seq_data={group_id: seq_data},
  262. sampling_params=sampling_params,
  263. block_tables=None,
  264. encoder_seq_data=seq_data,
  265. cross_block_table=None,
  266. )
  267. seqs.append(seq)
  268. # Run the model with the dummy inputs.
  269. num_layers = self.model_config.get_num_layers(self.parallel_config)
  270. kv_caches = [None] * num_layers
  271. finished_requests_ids = [seq.request_id for seq in seqs]
  272. model_input = self.prepare_model_input(
  273. seqs, finished_requests_ids=finished_requests_ids)
  274. intermediate_tensors = None
  275. self.execute_model(model_input, kv_caches, intermediate_tensors)
  276. torch.cuda.synchronize()
  277. return
  278. def _prepare_encoder_model_input_tensors(
  279. self,
  280. seq_group_metadata_list: List[SequenceGroupMetadata],
  281. model_input: EncoderDecoderModelInput,
  282. ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
  283. Optional[torch.Tensor]]:
  284. """Helper method to prepare the encoder- and cross-attn-related
  285. model inputs based on a given sequence group. These additional inputs
  286. are used to augment an already-computed `EncoderDecoderModelInput`
  287. data structure which already has decoder-related model inputs
  288. populated.
  289. Sets the following attn_metadata fields:
  290. * `num_encoder_tokens`
  291. * `encoder_seq_lens`
  292. * `encoder_seq_lens_tensor`
  293. * `max_encoder_seq_len`
  294. * `cross_slot_mapping`
  295. * `cross_block_tables`
  296. Constructs a new model inputs data structure, based on
  297. (1) the existing fields in the `model_inputs` argument,
  298. and (2) the following additional fields which are
  299. computed (or in the case of `attn_metadata`, updated)
  300. by this function:
  301. * attn_metadata
  302. * encoder_input_tokens
  303. * encoder_input_positions
  304. Arguments:
  305. * seq_group_metadata_list: list of sequence groups for which to
  306. compute inputs
  307. * model_inputs: model inputs data structure with decoder-oriented
  308. fields already computed.
  309. Return:
  310. * Updated model inputs data structure
  311. """
  312. if len(seq_group_metadata_list) == 0:
  313. return (model_input.attn_metadata, None, None)
  314. # Since we are not supporting chunked prefill either the entire
  315. # batch is prefill or it is decode
  316. is_prompt = seq_group_metadata_list[0].is_prompt
  317. # Build encoder inputs
  318. encoder_seq_lens: List[int] = []
  319. if is_prompt:
  320. # Prefill phase.
  321. cross_block_tables = self._empty_int32_tensor().view(
  322. len(seq_group_metadata_list), -1)
  323. # Extract input tokens/positions, cross-attention slot-mapping,
  324. # & seq len from each sequence group metadata
  325. (
  326. encoder_input_tokens,
  327. encoder_input_positions,
  328. cross_slot_mapping,
  329. ) = (
  330. [],
  331. [],
  332. [],
  333. )
  334. for seq_group_metadata in seq_group_metadata_list:
  335. # Build seq lens
  336. seq_len = seq_group_metadata.encoder_seq_data.get_len()
  337. token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
  338. encoder_seq_lens.append(seq_len)
  339. # Build slot mapping
  340. is_profile_run = (seq_group_metadata.block_tables is None)
  341. if is_profile_run:
  342. # During memory profiling, the block tables are not
  343. # initialized yet. In this case, we just use a dummy
  344. # slot mapping.
  345. # In embeddings, the block tables are {seq_id: None}.
  346. cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
  347. else:
  348. for i in range(0, seq_len):
  349. block_number = seq_group_metadata.cross_block_table[
  350. i // self.block_size]
  351. block_offset = i % self.block_size
  352. slot = block_number * self.block_size + block_offset
  353. cross_slot_mapping.append(slot)
  354. # Build encoder input tokens
  355. encoder_input_tokens.extend(token_ids)
  356. encoder_input_positions.extend(list(range(0, seq_len)))
  357. # Convert tokens/positions & cross-attention
  358. # slot-mapping to encoder input tensors
  359. encoder_input_tokens_tensor = self._list_to_long_tensor(
  360. encoder_input_tokens)
  361. encoder_input_positions_tensor = self._list_to_long_tensor(
  362. encoder_input_positions)
  363. cross_slot_mapping_tensor = self._list_to_long_tensor(
  364. cross_slot_mapping)
  365. else:
  366. # Decode phase.
  367. encoder_input_tokens_tensor = self._empty_long_tensor()
  368. encoder_input_positions_tensor = self._empty_long_tensor()
  369. cross_slot_mapping_tensor = self._empty_long_tensor()
  370. # Extract cross-attention block tables &
  371. # seq len from each sequence group metadata.
  372. # Cross-attention block tables are empty
  373. # during Aphrodite memory profiling.
  374. cross_block_tables = []
  375. for seq_group_metadata in seq_group_metadata_list:
  376. encoder_seq_lens.append(
  377. seq_group_metadata.encoder_seq_data.get_len())
  378. cross_block_table = seq_group_metadata.cross_block_table
  379. cross_block_tables.append([] if (
  380. cross_block_table is None) else cross_block_table)
  381. # Convert cross-attention block tables to encoder input tensor
  382. cross_block_tables = make_tensor_with_pad(
  383. cross_block_tables,
  384. max_len=max(
  385. len(block_table) for block_table in cross_block_tables),
  386. pad=0,
  387. dtype=torch.int32,
  388. device=self.device,
  389. )
  390. # Compute encoder sequence lengths & encoder
  391. # sequence starting offset tensors
  392. max_encoder_seq_len = max(encoder_seq_lens, default=0)
  393. encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
  394. encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
  395. 1,
  396. dtype=torch.int32,
  397. device=self.device)
  398. torch.cumsum(encoder_seq_lens_tensor,
  399. dim=0,
  400. dtype=encoder_seq_start_loc.dtype,
  401. out=encoder_seq_start_loc[1:])
  402. # Update attention metadata with encoder-oriented attributes
  403. attn_metadata = model_input.attn_metadata
  404. assert attn_metadata is not None
  405. (
  406. attn_metadata.num_encoder_tokens,
  407. attn_metadata.encoder_seq_lens,
  408. attn_metadata.encoder_seq_lens_tensor,
  409. attn_metadata.max_encoder_seq_len,
  410. attn_metadata.cross_slot_mapping,
  411. attn_metadata.cross_block_tables,
  412. ) = (
  413. sum(encoder_seq_lens),
  414. encoder_seq_lens,
  415. encoder_seq_lens_tensor,
  416. max_encoder_seq_len,
  417. cross_slot_mapping_tensor,
  418. cross_block_tables,
  419. )
  420. return (attn_metadata, encoder_input_tokens_tensor,
  421. encoder_input_positions_tensor)