from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
                    overload)

from tqdm import tqdm

from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import RequestOutputKind, SamplingParams
from aphrodite.common.utils import Counter, deprecate_kwargs, is_list_of
from aphrodite.endpoints.chat_utils import (ChatCompletionMessageParam,
                                            apply_hf_chat_template,
                                            apply_mistral_chat_template,
                                            parse_chat_messages)
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.inputs import PromptType, TextPrompt, TokensPrompt
from aphrodite.inputs.parse import parse_and_batch_prompt
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from aphrodite.modeling.guided_decoding.guided_fields import LLMGuidedOptions
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
                                                    MistralTokenizer,
                                                    get_cached_tokenizer)
from aphrodite.transformers_utils.tokenizer_group import TokenizerGroup


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See ParallelConfig.
        **kwargs: Arguments for :class:`~aphrodite.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~aphrodite.AsyncAphrodite` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False
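
    # Usage sketch (added for illustration, not part of the upstream file):
    # inside this context manager, calls to ``generate``/``encode`` that pass
    # the legacy ``prompts=``/``prompt_token_ids=`` keyword arguments trigger
    # the deprecation warning emitted by the ``deprecate_kwargs`` decorator
    # used below.
    #
    #     with LLM.deprecate_legacy_api():
    #         outputs = llm.generate(prompts=["Hello"])  # warns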

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        **kwargs,
    ) -> None:
        '''
        LLM constructor.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        '''

        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        removed_vision_keys = (
            "image_token_id",
            "image_feature_size",
            "image_input_shape",
            "image_input_type",
        )
        if any(k in kwargs for k in removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            **kwargs,
        )
        self.llm_engine = AphroditeEngine.from_engine_args(engine_args)
        self.request_counter = Counter()
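
    # Usage sketch for the constructor (added for illustration; the model name
    # is only a placeholder, any model path accepted by the engine works the
    # same way):
    #
    #     llm = LLM(model="facebook/opt-125m",
    #               dtype="auto",
    #               gpu_memory_utilization=0.9)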

    def get_tokenizer(self) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Because the cached tokenizer class is created dynamically, we have no
        # choice but to compare class names. This may misjudge a user-defined
        # tokenizer whose class name happens to start with 'Cached'.
        if tokenizer.__class__.__name__.startswith("Cached"):
            tokenizer_group.tokenizer = tokenizer
        else:
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> List[RequestOutput]:
  262. """Generates the completions for the input prompts.
  263. This class automatically batches the given prompts, considering
  264. the memory constraint. For the best performance, put all of your prompts
  265. into a single list and pass it to this method.
  266. Args:
  267. prompts: The prompts to the LLM. You may pass a sequence of prompts
  268. for batch inference. See :class:`~aphrodite.inputs.PromptType`
  269. for more details about the format of each prompts.
  270. sampling_params: The sampling parameters for text generation. If
  271. None, we use the default sampling parameters.
  272. When it is a single value, it is applied to every prompt.
  273. When it is a list, the list must have the same length as the
  274. prompts and it is paired one by one with the prompt.
  275. use_tqdm: Whether to use tqdm to display the progress bar.
  276. lora_request: LoRA request to use for generation, if any.
  277. prompt_adapter_request: Prompt Adapter request to use for
  278. generation, if any.
  279. Returns:
  280. A list of ``RequestOutput`` objects containing the
  281. generated completions in the same order as the input prompts.
  282. Note:
  283. Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
  284. considered legacy and may be deprecated in the future. You should
  285. instead pass them via the ``inputs`` parameter.
  286. """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple are "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return AphroditeEngine.validate_outputs(outputs, RequestOutput)
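
    # Usage sketch for ``generate`` (added for illustration; prompts and
    # sampling values are placeholders):
    #
    #     params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
    #     outputs = llm.generate(
    #         ["Hello, my name is", "The capital of France is"], params)
    #     for out in outputs:
    #         print(out.prompt, out.outputs[0].text)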

    def chat(
        self,
        messages: List[ChatCompletionMessageParam],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer, and the :meth:`generate` method is called to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation represented as a list of messages.
                Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be
                used.
            add_generation_prompt: If True, adds a generation template
                to each message.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        conversation, mm_data = parse_chat_messages(messages, model_config,
                                                    tokenizer)

        prompt_data: Union[str, List[int]]
        if isinstance(tokenizer, MistralTokenizer):
            prompt_data = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )
        else:
            prompt_data = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )

        prompt: PromptType
        if is_list_of(prompt_data, int):
            prompt = TokensPrompt(prompt_token_ids=prompt_data)
        else:
            prompt = TextPrompt(prompt=prompt_data)

        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data

        return self.generate(
            prompt,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
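
    # Usage sketch for ``chat`` (added for illustration; message content is a
    # placeholder). The conversation is templated into a single prompt and
    # passed through ``generate`` above:
    #
    #     outputs = llm.chat(
    #         [{"role": "system", "content": "You are a helpful assistant."},
    #          {"role": "user", "content": "What is tensor parallelism?"}],
    #         SamplingParams(max_tokens=128),
    #     )
    #     print(outputs[0].outputs[0].text)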

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~aphrodite.inputs.PromptType`
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            generated embeddings in the same order as the input prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return AphroditeEngine.validate_outputs(outputs, EmbeddingRequestOutput)
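
    # Usage sketch for ``encode`` (added for illustration; the model name is a
    # placeholder for any embedding model the engine supports, and the
    # ``.outputs.embedding`` attribute of ``EmbeddingRequestOutput`` is an
    # assumption about the output layout):
    #
    #     llm = LLM(model="intfloat/e5-mistral-7b-instruct")
    #     outputs = llm.encode(["Hello, my name is"])
    #     embedding = outputs[0].outputs.embedding  # list of floats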

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        # skip_tokenizer_init is now checked in engine
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: List[PromptType] = []
        for i in range(num_requests):
            item: PromptType

            if prompts is not None:
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
            else:
                raise AssertionError

            parsed_prompts.append(item)

        return parsed_prompts

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> None:
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, list) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      list) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, list) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_processor(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        if guided_options:
            if guided_options.guided_decoding_backend is None:
                decoding_config = self.llm_engine.get_decoding_config()
                guided_options.guided_decoding_backend = (
                    decoding_config.guided_decoding_backend)
            guided_logits_processor = get_local_guided_decoding_logits_processor(  # noqa
                guided_options.guided_decoding_backend, guided_options,
                self.get_tokenizer())
            if guided_logits_processor:
                if params.logits_processors is None:
                    params.logits_processors = []
                params.logits_processors.append(guided_logits_processor)
        return params
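
    # Usage sketch for guided decoding (added for illustration; the exact set
    # of ``LLMGuidedOptions`` keys, e.g. ``guided_regex``, is an assumption
    # based on the fields of ``GuidedDecodingRequest``). A dict passed as
    # ``guided_options_request`` is converted in ``generate`` and attached
    # here as a logits processor:
    #
    #     outputs = llm.generate(
    #         ["The answer is"],
    #         SamplingParams(max_tokens=8),
    #         guided_options_request=dict(guided_regex=r"\d+"),
    #     )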

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids)
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            out_spd = (total_out_toks /
                                       pbar.format_dict["elapsed"])
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(1)

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        return self.llm_engine.is_embedding_model()