@@ -1,14 +1,17 @@
-from typing import List, Optional, Union
+from contextlib import contextmanager
+from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
 
-import torch
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
+from aphrodite.common.inputs import (PromptInputs, PromptStrictInputs,
+                                     TextPrompt, TextTokensPrompt,
+                                     TokensPrompt, parse_and_batch_prompt)
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import MultiModalData
-from aphrodite.common.utils import Counter
+from aphrodite.common.utils import Counter, deprecate_kwargs
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.lora.request import LoRARequest
@@ -24,14 +27,19 @@ class LLM:
     mechanism and efficient memory management.
 
     NOTE: This class is intended to be used for offline inference. For online
-    serving, use the `AsyncLLMEngine` class instead.
-    NOTE: For the comprehensive list of arguments, see `EngineArgs`.
+    serving, use the :class:`~aphrodite.AsyncAphrodite` class instead.
+
+    NOTE: For the comprehensive list of arguments, see
+    :class:`~aphrodite.EngineArgs`.
 
     Args:
         model: The name or path of a HuggingFace Transformers model.
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer. Expect valid prompt_token_ids and None for prompt
+            from the input.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
@@ -42,12 +50,15 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq", "gptq", "quip" and "squeezellm". If None,
-            we first check the `quantization_config` attribute in the model
-            config file. If that is None, we assume the model weights are not
-            quantized and use `dtype` to determine the data type of the weights.
+            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+            If None, we first check the `quantization_config` attribute in the
+            model config file. If that is None, we assume the model weights are
+            not quantized and use `dtype` to determine the data type of
+            the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id.
         seed: The seed to initialize the random number generator for sampling.
         gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
             reserve for the model weights, activations, and KV cache. Higher
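
The hunks above only touch the constructor's docstring. For orientation, a minimal construction sketch exercising the newly documented arguments could look like the following; this is an assumption-laden illustration (the model name is a placeholder, and `from aphrodite import LLM` assumes the package re-exports `LLM` at the top level, as the docstring's `:class:` references suggest):

```python
# Hypothetical usage sketch for the documented constructor arguments.
from aphrodite import LLM  # assumed top-level re-export of this class

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder HF model repo
    tokenizer_revision="main",    # new arg: pin the tokenizer version
    skip_tokenizer_init=False,    # new arg: skip tokenizer/detokenizer setup
    quantization=None,            # or "awq" / "gptq" / "squeezellm" / "fp8"
    gpu_memory_utilization=0.9,
)
```
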
@@ -68,27 +79,40 @@ class LLM:
         max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode.
-        disable_custom_all_reduce: See ParallelConfig.
+        disable_custom_all_reduce: See ParallelConfig
     """
 
+    DEPRECATE_LEGACY: ClassVar[bool] = False
+    """A flag to toggle whether to deprecate the legacy generate/encode API."""
+
+    @classmethod
+    @contextmanager
+    def deprecate_legacy_api(cls):
+        cls.DEPRECATE_LEGACY = True
+
+        yield
+
+        cls.DEPRECATE_LEGACY = False
+
     def __init__(
         self,
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
-        enforce_eager: bool = True,
+        enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
-        enable_prefix_caching: bool = False,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
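
The `deprecate_legacy_api` context manager added above simply flips the class-level `DEPRECATE_LEGACY` flag for the duration of a block, which `deprecate_kwargs` consults through its `is_deprecated` callback on `generate`/`encode`. A hedged sketch of how a caller (for example, a test) might opt into the warnings, reusing the `llm` object from the sketch above:

```python
# Sketch only: inside the block, keyword calls to the legacy parameters
# ("prompts", "prompt_token_ids", "multi_modal_data") are reported as
# deprecated; outside the block they stay silent because the flag is False.
with LLM.deprecate_legacy_api():
    outputs = llm.generate(prompts=["Hello, my name is"])  # warns
outputs = llm.generate(prompts=["Hello, my name is"])      # no warning
```
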
@@ -97,11 +121,13 @@ class LLM:
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            skip_tokenizer_init=skip_tokenizer_init,
             trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             quantization=quantization,
             revision=revision,
+            tokenizer_revision=tokenizer_revision,
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
@@ -109,7 +135,6 @@ class LLM:
             max_context_len_to_capture=max_context_len_to_capture,
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
-            enable_prefix_caching=enable_prefix_caching,
             **kwargs,
         )
         self.llm_engine = AphroditeEngine.from_engine_args(engine_args)
@@ -117,23 +142,109 @@ class LLM:
 
     def get_tokenizer(
             self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_engine.tokenizer
+        return self.llm_engine.tokenizer.tokenizer
 
     def set_tokenizer(
         self,
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     ) -> None:
-        self.llm_engine.tokenizer = tokenizer
+        self.llm_engine.tokenizer.tokenizer = tokenizer
 
+    @overload  # LEGACY: single (prompt + optional token ids)
     def generate(
         self,
-        prompts: Optional[Union[str, List[str]]] = None,
+        prompts: str,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        prompt_token_ids: Optional[List[int]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (prompt + optional token ids)
+    def generate(
+        self,
+        prompts: List[str],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: single (token ids + optional prompt)
+    def generate(
+        self,
+        prompts: Optional[str] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        *,
+        prompt_token_ids: List[int],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (token ids + optional prompt)
+    def generate(
+        self,
+        prompts: Optional[List[str]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        *,
+        prompt_token_ids: List[List[int]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: single or multi token ids [pos-only]
+    def generate(
+        self,
+        prompts: None,
+        sampling_params: None,
+        prompt_token_ids: Union[List[int], List[List[int]]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload
+    def generate(
+        self,
+        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
+        *,
+        sampling_params: Optional[Union[SamplingParams,
+                                        Sequence[SamplingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @deprecate_kwargs("prompts",
+                      "prompt_token_ids",
+                      "multi_modal_data",
+                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+                      additional_message="Please use the 'inputs' parameter "
+                      "instead.")
+    def generate(
+        self,
+        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                       Optional[Union[str, List[str]]]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        Sequence[SamplingParams]]] = None,
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
 
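
With the unified signature above in place, `generate` accepts either the new single `inputs` argument (a string, a `TextPrompt`/`TokensPrompt`-style dict, or a sequence of them) or the legacy `prompts`/`prompt_token_ids` keywords, which are converted internally. A hedged usage sketch; the model name and token IDs are placeholders, and the top-level imports are assumed re-exports:

```python
from aphrodite import LLM, SamplingParams  # assumed top-level re-exports

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8, max_tokens=32)

# New API: one positional argument carrying the prompt inputs.
outputs = llm.generate(
    [
        "The capital of France is",             # plain string prompt
        {"prompt": "The capital of Italy is"},  # TextPrompt-style dict
        {"prompt_token_ids": [1, 2, 3]},        # TokensPrompt-style dict
    ],
    sampling_params=params,
)

# Legacy API: still accepted, routed through _convert_v1_inputs below.
outputs = llm.generate(prompts=["The capital of France is"],
                       sampling_params=params)
```
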
@@ -142,170 +253,275 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            prompts: A list of prompts to generate completions for.
+            inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
-                None, we use the default sampling parameters.
-                When it's a single value, it's applied to every prompt.
-                When it's a list, the list must have the same length as
-                the prompts and it's paired one-to-one with the prompts.
-            prompt_token_ids: A list of token IDs for the prompts. If None, we
-                use the tokenizer to convert the prompts to token IDs.
+                None, we use the default sampling parameters.
+                When it is a single value, it is applied to every prompt.
+                When it is a list, the list must have the same length as the
+                prompts and it is paired one by one with the prompt.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi-modal data to use for generation, if any.
 
         Returns:
             A list of `RequestOutput` objects containing the
             generated completions in the same order as the input prompts.
         """
+        if prompt_token_ids is not None or multi_modal_data is not None:
+            inputs = self._convert_v1_inputs(
+                prompts=cast(Optional[Union[str, List[str]]], prompts),
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+            )
+        else:
+            inputs = cast(
+                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                prompts)
+
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
 
-        requests_data = self._validate_and_prepare_requests(
-            prompts,
-            sampling_params,
-            prompt_token_ids,
-            lora_request,
-            multi_modal_data,
+        self._validate_and_add_requests(
+            inputs=inputs,
+            params=sampling_params,
+            lora_request=lora_request,
         )
 
-        # Add requests to the engine and run the engine
-        for request_data in requests_data:
-            self._add_request(**request_data)
+        outputs = self._run_engine(use_tqdm=use_tqdm)
+        return AphroditeEngine.validate_outputs(outputs, RequestOutput)
 
-        return self._run_engine(use_tqdm)
+    @overload  # LEGACY: single (prompt + optional token ids)
+    def encode(
+        self,
+        prompts: str,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        prompt_token_ids: Optional[List[int]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
 
+    @overload  # LEGACY: multi (prompt + optional token ids)
     def encode(
         self,
-        prompts: Optional[Union[str, List[str]]] = None,
+        prompts: List[str],
         pooling_params: Optional[Union[PoolingParams,
-                                       List[PoolingParams]]] = None,
+                                       Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: single (token ids + optional prompt)
+    def encode(
+        self,
+        prompts: Optional[str] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        *,
+        prompt_token_ids: List[int],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (token ids + optional prompt)
+    def encode(
+        self,
+        prompts: Optional[List[str]] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        *,
+        prompt_token_ids: List[List[int]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: single or multi token ids [pos-only]
+    def encode(
+        self,
+        prompts: None,
+        pooling_params: None,
+        prompt_token_ids: Union[List[int], List[List[int]]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload
+    def encode(
+        self,
+        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
+        *,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @deprecate_kwargs("prompts",
+                      "prompt_token_ids",
+                      "multi_modal_data",
+                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+                      additional_message="Please use the 'inputs' parameter "
+                      "instead.")
+    def encode(
+        self,
+        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                       Optional[Union[str, List[str]]]] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[EmbeddingRequestOutput]:
         """Generates the completions for the input prompts.
+
         NOTE: This class automatically batches the given prompts, considering
         the memory constraint. For the best performance, put all of your prompts
         into a single list and pass it to this method.
+
         Args:
-            prompts: A list of prompts to generate completions for.
+            inputs: The inputs to the LLM. You may pass a sequence of inputs for
+                batch inference. See
+                :class:`~aphrodite.inputs.PromptStrictInputs`
+                for more details about the format of each input.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
-            prompt_token_ids: A list of token IDs for the prompts. If None, we
-                use the tokenizer to convert the prompts to token IDs.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi modal data.
+
         Returns:
             A list of `EmbeddingRequestOutput` objects containing the
             generated embeddings in the same order as the input prompts.
         """
+        if prompt_token_ids is not None or multi_modal_data is not None:
+            inputs = self._convert_v1_inputs(
+                prompts=cast(Optional[Union[str, List[str]]], prompts),
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+            )
+        else:
+            inputs = cast(
+                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                prompts)
+
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
-        requests_data = self._validate_and_prepare_requests(
-            prompts,
-            pooling_params,
-            prompt_token_ids,
-            lora_request,
-            multi_modal_data,
+        self._validate_and_add_requests(
+            inputs=inputs,
+            params=pooling_params,
+            lora_request=lora_request,
         )
 
-        # Add requests to the engine and run the engine
-        for request_data in requests_data:
-            self._add_request(**request_data)
-
-        return self._run_engine(use_tqdm)
+        outputs = self._run_engine(use_tqdm=use_tqdm)
+        return AphroditeEngine.validate_outputs(outputs,
+                                                EmbeddingRequestOutput)
 
-    def _validate_and_prepare_requests(
+    # LEGACY
+    def _convert_v1_inputs(
         self,
         prompts: Optional[Union[str, List[str]]],
-        params: Union[Union[SamplingParams, PoolingParams],
-                      List[Union[SamplingParams,
-                                 PoolingParams]]],  # Unified parameter
-        prompt_token_ids: Optional[List[List[int]]] = None,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
-    ) -> List[dict]:
-        """Validates and prepares request data for adding to the engine.
-        Ensures prompts and token IDs are consistent, and returns a list of
-        dictionaries with request data for further processing.
-        """
-        if prompts is None and prompt_token_ids is None:
-            raise ValueError("Either prompts or prompt_token_ids must be "
-                             "provided.")
-        if isinstance(prompts, str):
-            # Convert a single prompt to a list.
-            prompts = [prompts]
-        if prompts is not None and prompt_token_ids is not None and len(
-                prompts) != len(prompt_token_ids):
-            raise ValueError(
-                "The lengths of prompts and prompt_token_ids must "
-                "be the same.")
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
+        multi_modal_data: Optional[MultiModalData],
+    ):
+        # skip_tokenizer_init is now checked in engine
+
+        if prompts is not None:
+            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
+        if prompt_token_ids is not None:
+            prompt_token_ids = [
+                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
+            ]
 
+        num_requests = None
         if prompts is not None:
             num_requests = len(prompts)
-        else:
-            assert prompt_token_ids is not None
+        if prompt_token_ids is not None:
+            if (num_requests is not None
+                    and num_requests != len(prompt_token_ids)):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
+
             num_requests = len(prompt_token_ids)
+        if num_requests is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
+
+        inputs: List[PromptInputs] = []
+        for i in range(num_requests):
+            if prompts is not None:
+                if prompt_token_ids is not None:
+                    item = TextTokensPrompt(
+                        prompt=prompts[i],
+                        prompt_token_ids=prompt_token_ids[i])
+                else:
+                    item = TextPrompt(prompt=prompts[i])
+            else:
+                if prompt_token_ids is not None:
+                    item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
+                else:
+                    raise AssertionError
+
+            if multi_modal_data is not None:
+                item["multi_modal_data"] = multi_modal_data
+
+            inputs.append(item)
+
+        return inputs
+
+    def _validate_and_add_requests(
+        self,
+        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
+                      Sequence[PoolingParams]],
+        lora_request: Optional[LoRARequest],
+    ) -> None:
+        if isinstance(inputs, (str, dict)):
+            # Convert a single prompt to a list.
+            inputs = [inputs]
+
+        num_requests = len(inputs)
 
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
-                             "be the same.")
-
-        if multi_modal_data:
-            multi_modal_data.data = multi_modal_data.data.to(torch.float16)
+                             "must be the same.")
 
         # Add requests to the engine.
-        requests_data = []
-        num_requests = len(prompts) if prompts is not None else len(
-            prompt_token_ids)
-        for i in range(num_requests):
-            prompt = prompts[i] if prompts is not None else None
-            token_ids = None if prompt_token_ids is None else prompt_token_ids[
-                i]
-            multi_modal_item = MultiModalData(
-                type=multi_modal_data.type,
-                data=multi_modal_data.data[i].unsqueeze(0),
-            ) if multi_modal_data else None
-
-            requests_data.append({
-                "prompt":
-                prompt,
-                "params":
-                params[i] if isinstance(params, list) else params,
-                "prompt_token_ids":
-                token_ids,
-                "lora_request":
-                lora_request,
-                "multi_modal_data":
-                multi_modal_item,
-            })
-
-        return requests_data
+        for i, request_inputs in enumerate(inputs):
+            self._add_request(
+                request_inputs,
+                params[i] if isinstance(params, Sequence) else params,
+                lora_request=lora_request,
+            )
 
     def _add_request(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]],
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(request_id,
-                                    prompt,
+                                    inputs,
                                     params,
-                                    prompt_token_ids,
-                                    lora_request=lora_request,
-                                    multi_modal_data=multi_modal_data)
+                                    lora_request=lora_request)
 
     def _run_engine(
-        self, use_tqdm: bool
+        self, *, use_tqdm: bool
     ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         # Initialize tqdm.
         if use_tqdm:
@@ -337,5 +553,4 @@ class LLM:
         # Sort the outputs by request ID.
         # This is necessary because some requests may be finished earlier than
         # its previous requests.
-        outputs = sorted(outputs, key=lambda x: int(x.request_id))
-        return outputs
+        return sorted(outputs, key=lambda x: int(x.request_id))
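
For the embedding path, the same calling convention applies to `encode`. A hedged sketch under the assumption that `LLM` is re-exported at the package root and that the placeholder repo points at an embedding-capable model; `PoolingParams` is imported from the module shown in the first hunk:

```python
from aphrodite import LLM  # assumed top-level re-export
from aphrodite.common.pooling_params import PoolingParams

llm = LLM(model="intfloat/e5-mistral-7b-instruct")  # placeholder embedding model
outputs = llm.encode(["What is the capital of France?"],
                     pooling_params=PoolingParams())
for output in outputs:
    # Each item is an EmbeddingRequestOutput, matching the return type above.
    print(output)
```
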
|