preprocess.py

import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from loguru import logger
from typing_extensions import assert_never

from aphrodite.common.config import ModelConfig
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup

from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                   SingletonPromptInputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

if TYPE_CHECKING:
    from aphrodite.multimodal import MultiModalDataDict

PromptComponents = Tuple[
    Optional[str], List[int], Optional["MultiModalDataDict"]
]
DecoderPromptComponents = Tuple[
    Optional[str], Optional[List[int]], Optional["MultiModalDataDict"]
]


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[BaseTokenizerGroup],
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer

    def get_tokenizer_group(self) -> BaseTokenizerGroup:
        if self.tokenizer is None:
            raise ValueError(
                "You cannot pass text prompts when "
                "`skip_tokenizer_init` is True"
            )
        return self.tokenizer

    def get_bos_token_id(
        self, lora_request: Optional[LoRARequest] = None
    ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning(
                "Using None for BOS token id because tokenizer "
                "is not initialized"
            )
            return None
        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def get_eos_token_id(
        self, lora_request: Optional[LoRARequest] = None
    ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning(
                "Using None for EOS token id because tokenizer "
                "is not initialized"
            )
            return None
        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def get_decoder_start_token_id(self) -> Optional[int]:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
        """
        if not self.is_encoder_decoder_model():
            logger.warning(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model."
            )
            return None

        if self.model_config is None or self.model_config.hf_config is None:
            logger.warning(
                "Using None for decoder start token id because "
                "model config is not available."
            )
            return None

        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )
        if dec_start_token_id is None:
            logger.warning(
                "Falling back on <BOS> for decoder start token id "
                "because decoder start token id is not available."
            )
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
        """
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        """
        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[List[int]],
    ) -> List[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder
        models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """
        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # No decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids.
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

        if (
            len(decoder_input_ids) == 0
            or decoder_input_ids[0] != decoder_start_token_id
        ):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids
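
    # Illustration of the normalization above (hypothetical ids, assuming
    # the model's decoder_start_token_id is 2 and its BOS token id is 0):
    #
    #   _prepare_decoder_input_ids_for_generation(None)      -> [2, 0]
    #   _prepare_decoder_input_ids_for_generation([5, 6])    -> [2, 5, 6]
    #   _prepare_decoder_input_ids_for_generation([2, 5, 6]) -> [2, 5, 6]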

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids
            )
        return prompt_token_ids
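
    # Sketch of the effect above, with a hypothetical request whose
    # prompt_adapter_num_virtual_tokens == 2; the placeholder id 0 is
    # prepended once per virtual token, reserving those positions for the
    # adapter downstream:
    #
    #   _apply_prompt_adapter([5, 6, 7], request) -> [0, 0, 5, 6, 7]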

    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
        tokenizer = self.get_tokenizer_group()
        return tokenizer.encode(
            request_id=request_id, prompt=prompt, lora_request=lora_request
        )

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        tokenizer = self.get_tokenizer_group()
        return await tokenizer.encode_async(
            request_id=request_id, prompt=prompt, lora_request=lora_request
        )

    def _extract_prompt_components(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """
        Extract the components of any single encoder or decoder input prompt.

        Arguments:

        * request_id
        * inputs: single encoder or decoder input prompt
        * lora_request: this is only valid for decoder prompts

        Returns:

        * prompt
        * prompt_token_ids
        * multi_modal_data
        """
        parsed = parse_singleton_prompt(inputs)

        if parsed["type"] == "str":
            prompt = parsed["content"]
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif parsed["type"] == "tokens":
            prompt = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
        elif parsed["type"] == "text":
            prompt = parsed["content"]["prompt"]
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
        else:
            assert_never(parsed)

        return prompt, prompt_token_ids, multi_modal_data
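
    # The three singleton prompt shapes handled above, illustrated with
    # hypothetical values (`ids(...)` stands in for tokenization):
    #
    #   "Hi there"                               -> ("Hi there", ids("Hi there"), None)
    #   {"prompt_token_ids": [1, 2, 3]}          -> (None, [1, 2, 3], None)
    #   {"prompt": "Hi", "multi_modal_data": mm} -> ("Hi", ids("Hi"), mm)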

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`."""
        parsed = parse_singleton_prompt(inputs)

        if parsed["type"] == "str":
            prompt = parsed["content"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif parsed["type"] == "tokens":
            prompt = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
        elif parsed["type"] == "text":
            prompt = parsed["content"]["prompt"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
        else:
            assert_never(parsed)

        return prompt, prompt_token_ids, multi_modal_data

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

        if encoder_mm_data is not None or decoder_mm_data is not None:
            raise ValueError(
                "Multi-modal encoder-decoder models are not supported yet"
            )

        decoder_prompt_ids = self._prepare_decoder_input_ids_for_generation(
            decoder_prompt_ids
        )

        return EncoderDecoderLLMInputs(
            prompt_token_ids=decoder_prompt_ids,
            prompt=decoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_ids,
            encoder_prompt=encoder_prompt,
        )
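
    # Note the field mapping above: in EncoderDecoderLLMInputs, the
    # `prompt_token_ids`/`prompt` fields carry the *decoder* prompt, while
    # the encoder prompt travels in the `encoder_*` fields.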

    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:

        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        """
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = self._extract_prompt_components(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                decoder_comps = None, None, None
            else:
                decoder_comps = self._extract_prompt_components(
                    decoder_input,
                    request_id=request_id,
                )
        else:
            encoder_comps = self._extract_prompt_components(
                inputs,
                request_id=request_id,
            )
            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`."""
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_task = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                encoder_comps = await encoder_task
                decoder_comps = None, None, None
            else:
                decoder_task = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )
                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_task, decoder_task
                )
        else:
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )
            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
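
    # Design note: in the async path above, an explicit prompt whose decoder
    # sub-prompt is present has its encoder and decoder components extracted
    # concurrently via asyncio.gather rather than awaited one after the
    # other, so neither tokenization blocks the other.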

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        prompt, prompt_token_ids, multi_modal_data = prompt_comps

        prompt_token_ids = self._apply_prompt_adapter(
            prompt_token_ids, prompt_adapter_request=prompt_adapter_request
        )

        return LLMInputs(
            prompt_token_ids=prompt_token_ids,
            prompt=prompt,
            multi_modal_data=multi_modal_data,
        )

    def _process_decoder_only_prompt(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """
        For decoder-only models:
        Process an input prompt into an :class:`LLMInputs` instance.

        Arguments:

        * inputs: input prompt
        * request_id
        * lora_request
        * prompt_adapter_request

        Returns:

        * :class:`LLMInputs` instance
        """
        prompt_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    def preprocess(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Preprocess the input prompt."""
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            return self._process_encoder_decoder_prompt(
                inputs,
                request_id=request_id,
            )

        if is_explicit_encoder_decoder_prompt(inputs):
            raise ValueError(
                "Cannot pass encoder-decoder prompt to decoder-only models"
            )

        # Decoder-only operation
        return self._process_decoder_only_prompt(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
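
    # Hypothetical usage sketch for the decoder-only path (names are
    # illustrative; `model_config` and `tokenizer_group` must come from the
    # surrounding engine setup):
    #
    #   preprocessor = InputPreprocessor(model_config, tokenizer_group)
    #   llm_inputs = preprocessor.preprocess(
    #       "What is the capital of France?",
    #       request_id="req-0",
    #   )
    #   llm_inputs["prompt_token_ids"]  # token ids fed to the model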

    async def preprocess_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`preprocess`."""
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            return await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )

        if is_explicit_encoder_decoder_prompt(inputs):
            raise ValueError(
                "Cannot pass encoder-decoder prompt to decoder-only models"
            )

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    def is_encoder_decoder_model(self) -> bool:
        return self.model_config.is_encoder_decoder_model
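
# Hypothetical end-to-end sketch for an encoder/decoder model (illustrative
# only; assumes a model config where is_encoder_decoder_model is True):
#
#   enc_dec_inputs = preprocessor.preprocess(
#       {
#           "encoder_prompt": "translate English to German: Hello",
#           "decoder_prompt": None,
#       },
#       request_id="req-1",
#   )
#   # enc_dec_inputs["encoder_prompt_token_ids"] holds the encoder token
#   # ids; enc_dec_inputs["prompt_token_ids"] holds the prepared decoder
#   # ids, beginning with decoder_start_token_id.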