123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600 |
- from contextlib import contextmanager
- from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
- from tqdm import tqdm
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
- from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
- from aphrodite.common.pooling_params import PoolingParams
- from aphrodite.common.sampling_params import SamplingParams
- from aphrodite.common.utils import Counter, deprecate_kwargs
- from aphrodite.engine.aphrodite_engine import AphroditeEngine
- from aphrodite.engine.args_tools import EngineArgs
- from aphrodite.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
- TextTokensPrompt, TokensPrompt,
- parse_and_batch_prompt)
- from aphrodite.lora.request import LoRARequest
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
- from aphrodite.transformers_utils.tokenizer import get_cached_tokenizer
class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
        **kwargs: Arguments for :class:`~aphrodite.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~aphrodite.AsyncAphrodite` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Turn on ``DEPRECATE_LEGACY`` for the duration of a ``with`` block.

        The flag value seen on entry is restored on exit, including when the
        body raises, so a failure inside the block cannot leave the flag
        switched on and nested uses do not reset each other.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        """Build the engine arguments and create the underlying engine.

        See the class docstring for the meaning of each argument.

        Raises:
            TypeError: If any of the removed vision-related keyword arguments
                is passed through ``**kwargs``.
        """
        # Stats logging is off unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)

        removed_vision_keys = ("image_token_id", "image_feature_size",
                               "image_input_shape", "image_input_type")
        if any(k in kwargs for k in removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = AphroditeEngine.from_engine_args(engine_args)
        # Source of the request IDs handed to the engine by `_add_request`.
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the tokenizer object currently held by the engine."""
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the engine's tokenizer.

        A tokenizer whose class name starts with ``"Cached"`` is installed
        as-is; any other tokenizer is first passed through
        ``get_cached_tokenizer``.
        """
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
            self.llm_engine.tokenizer.tokenizer = tokenizer
        else:
            self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
                tokenizer)

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def generate(
        self,
        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: A list of inputs to generate completions for. In this
                implementation the first positional parameter is named
                ``prompts``; it carries either the new-style inputs or the
                legacy prompt string(s).
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token IDs for the prompt(s). When given,
                ``prompts`` is interpreted as legacy text and both are
                converted to the new input format.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the engine's model config is in embedding mode.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for generation models "
                "(XForCausalLM).")

        if prompt_token_ids is not None:
            # Legacy call style: fold text and token IDs into prompt dicts.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(
                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return AphroditeEngine.validate_outputs(outputs, RequestOutput)

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def encode(
        self,
        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM. You may pass a sequence of inputs for
                batch inference. See
                :class:`~aphrodite.inputs.PromptStrictInputs`
                for more details about the format of each input. In this
                implementation the first positional parameter is named
                ``prompts``; it carries either the new-style inputs or the
                legacy prompt string(s).
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: LEGACY. Token IDs for the prompt(s). When given,
                ``prompts`` is interpreted as legacy text and both are
                converted to the new input format.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the engine's model config is not in embedding mode.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            # Legacy call style: fold text and token IDs into prompt dicts.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(
                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return AphroditeEngine.validate_outputs(outputs,
                                                EmbeddingRequestOutput)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ) -> List[PromptInputs]:
        """Convert legacy ``prompts``/``prompt_token_ids`` into prompt dicts.

        Each argument is first normalised to a batch with
        ``parse_and_batch_prompt``. One entry is produced per request: a
        ``TextTokensPrompt`` when both text and token IDs are given, a
        ``TextPrompt`` for text only, and a ``TokensPrompt`` for token IDs
        only.

        Raises:
            ValueError: If neither argument is provided, or if both are
                provided with different batch lengths.
        """
        # skip_tokenizer_init is now checked in engine
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if prompts is None:
            return [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]
        if prompt_token_ids is None:
            return [TextPrompt(prompt=text) for text in prompts]

        if len(prompts) != len(prompt_token_ids):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        return [
            TextTokensPrompt(prompt=text, prompt_token_ids=token_ids)
            for text, token_ids in zip(prompts, prompt_token_ids)
        ]

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Check per-request argument lengths and queue every request.

        A single prompt (``str`` or ``dict``) is treated as a batch of one.
        ``params`` and ``lora_request`` may each be a single value applied to
        every request, or a sequence paired one-to-one with ``inputs``.

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence whose
                length differs from the number of inputs.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        # Validate against `Sequence`, the same test used for indexing below,
        # so tuples are length-checked too and nothing is queued when the
        # lengths disagree.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

    def _add_request(
            self,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            lora_request: Optional[Union[List[LoRARequest],
                                         LoRARequest]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        """Queue one request under the next ID from ``request_counter``."""
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            inputs,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a tqdm progress bar with estimated
                input/output token throughput.

        Returns:
            The finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                for output in self.llm_engine.step():
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Guard against a zero reading so the speed estimate
                        # can never raise ZeroDivisionError.
                        elapsed = pbar.format_dict["elapsed"]
                        in_spd = total_in_toks / elapsed if elapsed else 0.0
                        out_spd = total_out_toks / elapsed if elapsed else 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even when a step fails.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
|