enc_dec_model_runner.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. import dataclasses
  2. from typing import Any, Dict, List, Optional, Tuple, Type, cast
  3. import torch
  4. import torch.distributed
  5. from loguru import logger
  6. from aphrodite.attention.backends.abstract import (AttentionBackend,
  7. AttentionMetadata)
  8. from aphrodite.attention.selector import (_Backend,
  9. get_env_variable_attn_backend,
  10. get_global_forced_attn_backend,
  11. global_force_attn_backend)
  12. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  13. LoRAConfig, ModelConfig, MultiModalConfig,
  14. ParallelConfig, PromptAdapterConfig,
  15. SchedulerConfig)
  16. from aphrodite.common.sampling_params import SamplingParams
  17. from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
  18. SamplerOutput, SequenceGroupMetadata)
  19. from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
  20. make_tensor_with_pad)
  21. from aphrodite.inputs import INPUT_REGISTRY
  22. from aphrodite.modeling import SamplingMetadata
  23. from aphrodite.task_handler.model_runner import (
  24. _PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder,
  25. ModelInputForGPUWithSamplingMetadata)
  26. from aphrodite.task_handler.model_runner_base import (
  27. _add_attn_metadata_broadcastable_dict,
  28. _add_sampling_metadata_broadcastable_dict)
  29. from aphrodite.task_handler.utils import assert_enc_dec_mr_supported_scenario
  30. @dataclasses.dataclass(frozen=True)
  31. class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
  32. """
  33. Used by the EncoderDecoderModelRunner.
  34. """
  35. encoder_input_tokens: Optional[torch.Tensor] = None
  36. encoder_input_positions: Optional[torch.Tensor] = None
  37. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  38. tensor_dict = {
  39. "input_tokens": self.input_tokens,
  40. "input_positions": self.input_positions,
  41. "encoder_input_tokens": self.encoder_input_tokens,
  42. "encoder_input_positions": self.encoder_input_positions,
  43. "virtual_engine": self.virtual_engine,
  44. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  45. "finished_requests_ids": self.finished_requests_ids,
  46. }
  47. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  48. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  49. self.sampling_metadata)
  50. return tensor_dict
  51. @classmethod
  52. def from_broadcasted_tensor_dict(
  53. cls,
  54. tensor_dict: Dict[str, Any],
  55. attn_backend: Optional["AttentionBackend"] = None,
  56. ) -> "EncoderDecoderModelInput":
  57. return cast(
  58. EncoderDecoderModelInput,
  59. super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
  60. class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
  61. _model_input_cls: Type[EncoderDecoderModelInput] = (
  62. EncoderDecoderModelInput)
  63. _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
  64. def __init__(
  65. self,
  66. model_config: ModelConfig,
  67. parallel_config: ParallelConfig,
  68. scheduler_config: SchedulerConfig,
  69. device_config: DeviceConfig,
  70. cache_config: CacheConfig,
  71. load_config: LoadConfig,
  72. lora_config: Optional[LoRAConfig],
  73. kv_cache_dtype: Optional[str] = "auto",
  74. is_driver_worker: bool = False,
  75. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  76. multimodal_config: Optional[MultiModalConfig] = None,
  77. **kwargs,
  78. ):
  79. '''
  80. EncoderDecoderModelRunner constructor.
  81. `lora_config`, `multimodal_config`, and prompt_adapter_config are
  82. unused (since these features are not yet supported for encoder/decoder
  83. models) but these arguments are present here for compatibility with
  84. the base-class constructor.
  85. '''
  86. self._maybe_force_supported_attention_backend()
  87. super().__init__(
  88. model_config,
  89. parallel_config,
  90. scheduler_config,
  91. device_config,
  92. cache_config,
  93. load_config,
  94. lora_config=None,
  95. kv_cache_dtype=kv_cache_dtype,
  96. is_driver_worker=is_driver_worker,
  97. **kwargs,
  98. )
  99. # Crash for unsupported encoder/scenarios
  100. assert_enc_dec_mr_supported_scenario(self)
  101. def _maybe_force_supported_attention_backend(self):
  102. '''
  103. Force Aphrodite to use the XFormers attention backend,
  104. which is currently the only supported option.
  105. '''
  106. def raise_backend_err():
  107. # The user has specified an attention backend override
  108. # which is invalid for encoder/decoder models
  109. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
  110. maybe_env_var_forced_backend = get_env_variable_attn_backend()
  111. maybe_global_forced_backend = get_global_forced_attn_backend()
  112. is_forced_by_global = maybe_global_forced_backend is not None
  113. is_forced_by_env_var = maybe_env_var_forced_backend is not None
  114. if not (is_forced_by_global or is_forced_by_env_var):
  115. # The user has not already specified an attention backend
  116. # override
  117. logger.info("EncoderDecoderModelRunner requires "
  118. "XFormers backend; overriding backend "
  119. "auto-selection and forcing XFormers.")
  120. global_force_attn_backend(_Backend.XFORMERS)
  121. elif is_forced_by_global:
  122. # Backend override enforced by global variable takes
  123. # precedence over Aphrodite backend environment variable.
  124. if maybe_global_forced_backend != _Backend.XFORMERS:
  125. raise_backend_err()
  126. elif is_forced_by_env_var:
  127. # Backend override enforced by Aphrodite backend
  128. # environment variable
  129. if maybe_env_var_forced_backend != _Backend.XFORMERS:
  130. raise_backend_err()
  131. def _list_to_int32_tensor(
  132. self,
  133. _list: List[int],
  134. ) -> torch.Tensor:
  135. return torch.tensor(_list, dtype=torch.int32, device=self.device)
  136. def _list_to_long_tensor(
  137. self,
  138. _list: List[int],
  139. ) -> torch.Tensor:
  140. return torch.tensor(_list, dtype=torch.long, device=self.device)
  141. def _empty_int32_tensor(self) -> torch.Tensor:
  142. return self._list_to_int32_tensor([])
  143. def _empty_long_tensor(self) -> torch.Tensor:
  144. return self._list_to_long_tensor([])
  145. @torch.inference_mode()
  146. def execute_model(
  147. self,
  148. model_input: EncoderDecoderModelInput,
  149. kv_caches: List[torch.Tensor],
  150. intermediate_tensors: Optional[IntermediateTensors] = None,
  151. num_steps: int = 1,
  152. ) -> Optional[List[PoolerOutput]]:
  153. if num_steps > 1:
  154. raise ValueError("num_steps > 1 is not supported in "
  155. "EncoderDecoderModelRunner")
  156. model_executable = self.model
  157. seqlen_agnostic_kwargs = {
  158. "finished_requests_ids": model_input.finished_requests_ids,
  159. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  160. } if self.has_seqlen_agnostic else {}
  161. hidden_or_intermediate_states = model_executable(
  162. input_ids=model_input.input_tokens,
  163. positions=model_input.input_positions,
  164. encoder_input_ids=model_input.encoder_input_tokens,
  165. encoder_positions=model_input.encoder_input_positions,
  166. kv_caches=kv_caches,
  167. attn_metadata=model_input.attn_metadata,
  168. intermediate_tensors=intermediate_tensors,
  169. **seqlen_agnostic_kwargs)
  170. logits = self.model.compute_logits(hidden_or_intermediate_states,
  171. model_input.sampling_metadata)
  172. if not self.is_driver_worker:
  173. return []
  174. # Sample the next token.
  175. output: SamplerOutput = self.model.sample(
  176. logits=logits,
  177. sampling_metadata=model_input.sampling_metadata,
  178. )
  179. return [output]
  180. def make_model_input_from_broadcasted_tensor_dict(
  181. self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
  182. return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
  183. tensor_dict,
  184. attn_backend=self.attn_backend,
  185. )
  186. def prepare_model_input(
  187. self,
  188. seq_group_metadata_list: List[SequenceGroupMetadata],
  189. virtual_engine: int = 0,
  190. finished_requests_ids: Optional[List[str]] = None
  191. ) -> EncoderDecoderModelInput:
  192. """Prepare the model input based on a given sequence group, including
  193. metadata for the sampling step.
  194. Since chunked prefill is not supported for encoder/decoder models,
  195. `input_tokens` is assumed to be either entirely prefill tokens or
  196. entirely decode tokens.
  197. """
  198. model_input = self._prepare_model_input_tensors(
  199. seq_group_metadata_list, finished_requests_ids)
  200. (
  201. attn_metadata,
  202. encoder_input_tokens_tensor,
  203. encoder_input_positions_tensor,
  204. ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
  205. model_input))
  206. # Inject attn_metadata encoder/cross-attention fields &
  207. # encoder input tokens/positions into model_input.
  208. # Frozen dataclass fields cannot be modified, so use
  209. # dataclasses.replace to construct a new model input
  210. # instance.
  211. model_input = dataclasses.replace(
  212. model_input,
  213. attn_metadata=attn_metadata,
  214. encoder_input_tokens=encoder_input_tokens_tensor,
  215. encoder_input_positions=encoder_input_positions_tensor,
  216. )
  217. sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
  218. model_input.seq_lens,
  219. model_input.query_lens,
  220. self.device,
  221. self.pin_memory)
  222. is_prompt = (seq_group_metadata_list[0].is_prompt
  223. if seq_group_metadata_list else None)
  224. return dataclasses.replace(model_input,
  225. sampling_metadata=sampling_metadata,
  226. is_prompt=is_prompt,
  227. virtual_engine=virtual_engine)
  228. @torch.inference_mode()
  229. def profile_run(self) -> None:
  230. # Enable top-k sampling to reflect the accurate memory usage.
  231. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  232. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  233. max_num_seqs = self.scheduler_config.max_num_seqs
  234. # Profile memory usage with max_num_sequences sequences and the total
  235. # number of tokens equal to max_num_batched_tokens.
  236. seqs: List[SequenceGroupMetadata] = []
  237. model_config = self.model_config
  238. batch_size = 0
  239. for group_id in range(max_num_seqs):
  240. seq_len = (max_num_batched_tokens // max_num_seqs +
  241. (group_id < max_num_batched_tokens % max_num_seqs))
  242. batch_size += seq_len
  243. seq_data, _ = INPUT_REGISTRY \
  244. .dummy_data_for_profiling(model_config, seq_len)
  245. # Having more tokens is over-conservative but otherwise fine
  246. assert len(seq_data.prompt_token_ids) >= seq_len, (
  247. f"Expected at least {seq_len} dummy tokens for profiling, "
  248. f"but got: {len(seq_data.prompt_token_ids)}")
  249. seq = SequenceGroupMetadata(
  250. request_id=str(group_id),
  251. is_prompt=True,
  252. seq_data={group_id: seq_data},
  253. sampling_params=sampling_params,
  254. block_tables=None,
  255. encoder_seq_data=seq_data,
  256. cross_block_table=None,
  257. )
  258. seqs.append(seq)
  259. # Run the model with the dummy inputs.
  260. num_layers = self.model_config.get_num_layers(self.parallel_config)
  261. kv_caches = [None] * num_layers
  262. finished_requests_ids = [seq.request_id for seq in seqs]
  263. model_input = self.prepare_model_input(
  264. seqs, finished_requests_ids=finished_requests_ids)
  265. intermediate_tensors = None
  266. self.execute_model(model_input, kv_caches, intermediate_tensors)
  267. torch.cuda.synchronize()
  268. return
  269. def _prepare_encoder_model_input_tensors(
  270. self,
  271. seq_group_metadata_list: List[SequenceGroupMetadata],
  272. model_input: EncoderDecoderModelInput,
  273. ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
  274. Optional[torch.Tensor]]:
  275. """Helper method to prepare the encoder- and cross-attn-related
  276. model inputs based on a given sequence group. These additional inputs
  277. are used to augment an already-computed `EncoderDecoderModelInput`
  278. data structure which already has decoder-related model inputs
  279. populated.
  280. Sets the following attn_metadata fields:
  281. * `num_encoder_tokens`
  282. * `encoder_seq_lens`
  283. * `encoder_seq_lens_tensor`
  284. * `max_encoder_seq_len`
  285. * `cross_slot_mapping`
  286. * `cross_block_tables`
  287. Constructs a new model inputs data structure, based on
  288. (1) the existing fields in the `model_inputs` argument,
  289. and (2) the following additional fields which are
  290. computed (or in the case of `attn_metadata`, updated)
  291. by this function:
  292. * attn_metadata
  293. * encoder_input_tokens
  294. * encoder_input_positions
  295. Arguments:
  296. * seq_group_metadata_list: list of sequence groups for which to
  297. compute inputs
  298. * model_inputs: model inputs data structure with decoder-oriented
  299. fields already computed.
  300. Return:
  301. * Updated model inputs data structure
  302. """
  303. if len(seq_group_metadata_list) == 0:
  304. return (model_input.attn_metadata, None, None)
  305. # Since we are not supporting chunked prefill either the entire
  306. # batch is prefill or it is decode
  307. is_prompt = seq_group_metadata_list[0].is_prompt
  308. # Build encoder inputs
  309. encoder_seq_lens: List[int] = []
  310. if is_prompt:
  311. # Prefill phase.
  312. cross_block_tables = self._empty_int32_tensor().view(
  313. len(seq_group_metadata_list), -1)
  314. # Extract input tokens/positions, cross-attention slot-mapping,
  315. # & seq len from each sequence group metadata
  316. (
  317. encoder_input_tokens,
  318. encoder_input_positions,
  319. cross_slot_mapping,
  320. ) = (
  321. [],
  322. [],
  323. [],
  324. )
  325. for seq_group_metadata in seq_group_metadata_list:
  326. # Build seq lens
  327. seq_len = seq_group_metadata.encoder_seq_data.get_len()
  328. token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
  329. encoder_seq_lens.append(seq_len)
  330. # Build slot mapping
  331. is_profile_run = (seq_group_metadata.block_tables is None)
  332. if is_profile_run:
  333. # During memory profiling, the block tables are not
  334. # initialized yet. In this case, we just use a dummy
  335. # slot mapping.
  336. # In embeddings, the block tables are {seq_id: None}.
  337. cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
  338. else:
  339. for i in range(0, seq_len):
  340. block_number = seq_group_metadata.cross_block_table[
  341. i // self.block_size]
  342. block_offset = i % self.block_size
  343. slot = block_number * self.block_size + block_offset
  344. cross_slot_mapping.append(slot)
  345. # Build encoder input tokens
  346. encoder_input_tokens.extend(token_ids)
  347. encoder_input_positions.extend(list(range(0, seq_len)))
  348. # Convert tokens/positions & cross-attention
  349. # slot-mapping to encoder input tensors
  350. encoder_input_tokens_tensor = self._list_to_long_tensor(
  351. encoder_input_tokens)
  352. encoder_input_positions_tensor = self._list_to_long_tensor(
  353. encoder_input_positions)
  354. cross_slot_mapping_tensor = self._list_to_long_tensor(
  355. cross_slot_mapping)
  356. else:
  357. # Decode phase.
  358. encoder_input_tokens_tensor = self._empty_long_tensor()
  359. encoder_input_positions_tensor = self._empty_long_tensor()
  360. cross_slot_mapping_tensor = self._empty_long_tensor()
  361. # Extract cross-attention block tables &
  362. # seq len from each sequence group metadata.
  363. # Cross-attention block tables are empty
  364. # during Aphrodite memory profiling.
  365. cross_block_tables = []
  366. for seq_group_metadata in seq_group_metadata_list:
  367. encoder_seq_lens.append(
  368. seq_group_metadata.encoder_seq_data.get_len())
  369. cross_block_table = seq_group_metadata.cross_block_table
  370. cross_block_tables.append([] if (
  371. cross_block_table is None) else cross_block_table)
  372. # Convert cross-attention block tables to encoder input tensor
  373. cross_block_tables = make_tensor_with_pad(
  374. cross_block_tables,
  375. max_len=max(
  376. len(block_table) for block_table in cross_block_tables),
  377. pad=0,
  378. dtype=torch.int32,
  379. device=self.device,
  380. )
  381. # Compute encoder sequence lengths & encoder
  382. # sequence starting offset tensors
  383. max_encoder_seq_len = max(encoder_seq_lens, default=0)
  384. encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
  385. encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
  386. 1,
  387. dtype=torch.int32,
  388. device=self.device)
  389. torch.cumsum(encoder_seq_lens_tensor,
  390. dim=0,
  391. dtype=encoder_seq_start_loc.dtype,
  392. out=encoder_seq_start_loc[1:])
  393. # Update attention metadata with encoder-oriented attributes
  394. attn_metadata = model_input.attn_metadata
  395. assert attn_metadata is not None
  396. (
  397. attn_metadata.num_encoder_tokens,
  398. attn_metadata.encoder_seq_lens,
  399. attn_metadata.encoder_seq_lens_tensor,
  400. attn_metadata.max_encoder_seq_len,
  401. attn_metadata.cross_slot_mapping,
  402. attn_metadata.cross_block_tables,
  403. ) = (
  404. sum(encoder_seq_lens),
  405. encoder_seq_lens,
  406. encoder_seq_lens_tensor,
  407. max_encoder_seq_len,
  408. cross_slot_mapping_tensor,
  409. cross_block_tables,
  410. )
  411. return (attn_metadata, encoder_input_tokens_tensor,
  412. encoder_input_positions_tensor)