# preprocess.py

import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from loguru import logger
from typing_extensions import assert_never

from aphrodite.common.config import ModelConfig
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup

from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
                   SingletonPrompt)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

if TYPE_CHECKING:
    from aphrodite.multimodal import MultiModalDataDict

PromptComponents = Tuple[Optional[str], List[int],
                         Optional["MultiModalDataDict"]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional["MultiModalDataDict"]]


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[BaseTokenizerGroup],
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer

    def get_tokenizer_group(self) -> BaseTokenizerGroup:
        if self.tokenizer is None:
            raise ValueError("You cannot pass text prompts when "
                             "`skip_tokenizer_init` is True")

        return self.tokenizer

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def get_decoder_start_token_id(self) -> Optional[int]:
        '''
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
        '''

        if not self.is_encoder_decoder_model():
            logger.warning("Using None for decoder start token id because "
                           "this is not an encoder/decoder model.")
            return None

        if (self.model_config is None or self.model_config.hf_config is None):
            logger.warning("Using None for decoder start token id because "
                           "model config is not available.")
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning("Falling back on <BOS> for decoder start token id "
                           "because decoder start token id is not available.")
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
        '''
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder prompt in
        different ways; as new models are added, it is intended
        that this function will be extended to produce differing
        default decoder prompts, depending on the model variety.

        Absent a special case, the default behavior of this method
        is to mirror the behavior of the HuggingFace (HF)
        GenerationMixin for a None decoder prompt, which is to
        employ a logit processor setting to force the first decoded
        token to be <BOS>. Here, this behavior is approximated by
        having the "default" decoder prompt be <BOS>.

        However, it is possible that in the future other models may
        have different or more complex logic for the default decoder
        prompt. This motivates having a dedicated helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        '''

        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[List[int]],
    ) -> List[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder
        models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids
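
    # Illustration of _prepare_decoder_input_ids_for_generation (hypothetical
    # token ids; assumes decoder_start_token_id resolves to 2, which also
    # serves as <BOS> here):
    #
    #     [5, 6] -> [2, 5, 6]   # start id prepended
    #     [2, 5] -> [2, 5]      # already starts with the start id; unchanged
    #     None   -> [2]         # default <BOS> decoder prompt is used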
    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)

        return prompt_token_ids
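
    # Illustration of _apply_prompt_adapter (hypothetical values): a request
    # with prompt_adapter_num_virtual_tokens == 4 reserves four placeholder
    # positions (id 0) ahead of the real prompt:
    #
    #     [10, 11] -> [0, 0, 0, 0, 10, 11]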
    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
        tokenizer = self.get_tokenizer_group()

        return tokenizer.encode(request_id=request_id,
                                prompt=prompt,
                                lora_request=lora_request)

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        tokenizer = self.get_tokenizer_group()

        return await tokenizer.encode_async(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)

    def _extract_prompt_components(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        '''
        Extract the components of any single encoder or decoder input prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * request_id
        * lora_request: this is only valid for decoder prompts

        Returns:

        * prompt
        * prompt_token_ids
        * multi_modal_data
        '''

        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "str":
            prompt_text = parsed["content"]
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif parsed["type"] == "tokens":
            prompt_text = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
        elif parsed["type"] == "text":
            prompt_text = parsed["content"]["prompt"]
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
        else:
            assert_never(parsed)

        return prompt_text, prompt_token_ids, multi_modal_data
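
    # The three singleton prompt forms accepted by _extract_prompt_components,
    # as dispatched on above (a sketch; the dict shapes follow the "tokens"
    # and "text" branches):
    #
    #     "Hello"                            # plain string, tokenized here
    #     {"prompt_token_ids": [1, 2, 3]}    # pre-tokenized prompt
    #     {"prompt": "Hello",
    #      "multi_modal_data": ...}          # text prompt with optional
    #                                        # multi-modal data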
    async def _extract_prompt_components_async(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`."""
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "str":
            prompt_text = parsed["content"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt_text,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif parsed["type"] == "tokens":
            prompt_text = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
        elif parsed["type"] == "text":
            prompt_text = parsed["content"]["prompt"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt_text,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
        else:
            assert_never(parsed)

        return prompt_text, prompt_token_ids, multi_modal_data

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

        if encoder_mm_data is not None or decoder_mm_data is not None:
            raise ValueError("Multi-modal encoder-decoder models are "
                             "not supported yet")

        decoder_prompt_ids = (
            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))

        return EncoderDecoderLLMInputs(
            prompt_token_ids=decoder_prompt_ids,
            prompt=decoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_ids,
            encoder_prompt=encoder_prompt,
        )

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        '''
        For encoder/decoder models only:
        Process an input prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        There are two types of input prompts: singleton prompts,
        which carry only the encoder prompt, and explicit
        encoder/decoder prompts, which carry both the encoder and
        the decoder prompts as member variables.

        This function handles the following scenarios:

        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * prompt: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        '''

        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_comps = self._extract_prompt_components(
                prompt["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := prompt["decoder_prompt"]) is None:
                decoder_comps = None, None, None
            else:
                decoder_comps = self._extract_prompt_components(
                    decoder_input,
                    request_id=request_id,
                )
        else:
            encoder_comps = self._extract_prompt_components(
                prompt,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
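
    # An explicit encoder/decoder prompt handled by
    # _process_encoder_decoder_prompt is a dict carrying both sub-prompts;
    # e.g. (a sketch, values hypothetical):
    #
    #     {
    #         "encoder_prompt": "translate English to German: Hello",
    #         "decoder_prompt": None,  # None -> default decoder prompt
    #     }
    #
    # Any other prompt shape is treated as a singleton encoder prompt.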
    async def _process_encoder_decoder_prompt_async(
        self,
        prompt: PromptType,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`."""
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(prompt):
            encoder_task = self._extract_prompt_components_async(
                prompt["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := prompt["decoder_prompt"]) is None:
                encoder_comps = await encoder_task
                decoder_comps = None, None, None
            else:
                decoder_task = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )

                # Tokenize the encoder and decoder sub-prompts concurrently.
                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_task, decoder_task)
        else:
            encoder_comps = await self._extract_prompt_components_async(
                prompt,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        prompt, prompt_token_ids, multi_modal_data = prompt_comps

        prompt_token_ids = self._apply_prompt_adapter(
            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=prompt,
                         multi_modal_data=multi_modal_data)

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        '''
        For decoder-only models:
        Process an input prompt into an :class:`LLMInputs` instance.

        Arguments:

        * prompt: input prompt
        * request_id
        * lora_request
        * prompt_adapter_request

        Returns:

        * :class:`LLMInputs` instance
        '''

        prompt_comps = self._extract_prompt_components(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _process_decoder_only_prompt_async(
        self,
        prompt: SingletonPrompt,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._extract_prompt_components_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    def preprocess(
        self,
        prompt: PromptType,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Preprocess the input prompt."""
        if self.is_encoder_decoder_model():
            # Encoder-decoder models require a special mapping of
            # input prompts to the encoder & decoder.
            return self._process_encoder_decoder_prompt(
                prompt,
                request_id=request_id,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return self._process_decoder_only_prompt(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def preprocess_async(
        self,
        prompt: PromptType,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`preprocess`."""
        if self.is_encoder_decoder_model():
            # Encoder-decoder models require a special mapping of
            # input prompts to the encoder & decoder.
            return await self._process_encoder_decoder_prompt_async(
                prompt,
                request_id=request_id,
            )

        if is_explicit_encoder_decoder_prompt(prompt):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    def is_encoder_decoder_model(self) -> bool:
        return self.model_config.is_encoder_decoder_model
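

# Example usage (a sketch, not part of this file: it assumes an
# already-constructed `model_config` and `tokenizer_group`; building those
# is engine-specific and omitted here):
#
#     preprocessor = InputPreprocessor(model_config, tokenizer_group)
#     llm_inputs = preprocessor.preprocess(
#         "Hello, world!",      # any PromptType is accepted
#         request_id="req-0",
#     )
#     # For decoder-only models this yields an LLMInputs value whose
#     # prompt_token_ids field holds the tokenized prompt.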