
fix the LLM class for quantization

AlpinDale, 1 year ago
parent revision c70abc7522
1 changed file with 58 additions and 32 deletions
      aphrodite/endpoints/llm.py

aphrodite/endpoints/llm.py (+58 -32)

@@ -13,33 +13,46 @@ from aphrodite.common.utils import Counter
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
 
-    This class includes a tokenizer, a language model (possible distributed
+    This class includes a tokenizer, a language model (possibly distributed
     across multiple GPUs), and GPU memory space allocated for intermediate
     states (aka KV cache). Given a batch of prompts and sampling parameters,
     this class generates texts from the model, using an intelligent batching
     mechanism and efficient memory management.
 
     NOTE: This class is intended to be used for offline inference. For online
-    serving, use the `AsyncAphrodite` class instead.
+    serving, use the `AsyncLLMEngine` class instead.
     NOTE: For the comprehensive list of arguments, see `EngineArgs`.
 
     Args:
-        model: The name or path of a compatible HuggingFace Transformer model.
-        tokenizer: The name or path of a HF Transformers tokenizer.
+        model: The name or path of a HuggingFace Transformers model.
+        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
-        tensor_parallel_size: The number of GPUs to use for distribtuted
+        tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
-        dtype: The datatype for the model weights and activations. Currently 
-            Aphrodite supports `float32`, `float16`, and `bfloat16`. If `auto`
-            is used, it'll use the `torch_dtype` attribute specified in the model
-            config file. However, if the `torch_dtype` in the config is `float32`,
-            we will use `bfloat16` if your GPU supports it, otherwise `float16`.
-        seed: The seed to initialize the RNG for sampling.
+        dtype: The data type for the model weights and activations. Currently,
+            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+            the `torch_dtype` attribute specified in the model config file.
+            However, if the `torch_dtype` in the config is `float32`, we will
+            use `float16` instead.
+        quantization: The method used to quantize the model weights. Currently,
+            we support "awq". If None, we assume the model weights are not
+            quantized and use `dtype` to determine the data type of the weights.
         revision: The specific model version to use. It can be a branch name,
-            a tag name, or a commit ID.
+            a tag name, or a commit id.
+        seed: The seed to initialize the random number generator for sampling.
+        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
+            reserve for the model weights, activations, and KV cache. Higher
+            values will increase the KV cache size and thus improve the model's
+            throughput. However, if the value is too high, it may cause out-of-
+            memory (OOM) errors.
+        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
+            This can be used for temporarily storing the states of the requests
+            when their `best_of` sampling parameters are larger than 1. If all
+            requests will have `best_of=1`, you can safely set this to 0.
+            Otherwise, too small values may cause out-of-memory (OOM) errors.
     """
 
     def __init__(
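
The hunk above documents the new constructor arguments (`quantization`, `revision`, `gpu_memory_utilization`, `swap_space`). A minimal usage sketch of the updated signature; the checkpoint name is illustrative, and the import path simply mirrors the file this commit touches:

```python
from aphrodite.endpoints.llm import LLM

# A sketch of the updated constructor; the model name below is a
# placeholder for any AWQ-quantized checkpoint.
llm = LLM(
    model="TheBloke/Llama-2-7B-AWQ",  # assumed AWQ-quantized checkpoint
    quantization="awq",               # new: weights are AWQ-quantized
    revision="main",                  # new: branch name, tag name, or commit id
    dtype="auto",                     # activation dtype; weights come quantized
    gpu_memory_utilization=0.9,       # newly exposed: fraction of GPU memory
    swap_space=4,                     # newly exposed: GiB of CPU swap per GPU
)
```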
@@ -50,7 +63,11 @@ class LLM:
         trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
+        quantization: Optional[str] = None,
+        revision: Optional[str] = None,
         seed: int = 0,
+        gpu_memory_utilization: float = 0.9,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -62,21 +79,25 @@ class LLM:
             trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
+            quantization=quantization,
+            revision=revision,
             seed=seed,
+            gpu_memory_utilization=gpu_memory_utilization,
+            swap_space=swap_space,
             **kwargs,
         )
-        self.aphrodite_engine = AphroditeEngine.from_engine_args(engine_args)
+        self.llm_engine = AphroditeEngine.from_engine_args(engine_args)
         self.request_counter = Counter()
-    
+
     def get_tokenizer(
             self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.aphrodite_engine.tokenizer
-    
+        return self.llm_engine.tokenizer
+
     def set_tokenizer(
         self,
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     ) -> None:
-        self.aphrodite_engine.tokenizer = tokenizer    
+        self.llm_engine.tokenizer = tokenizer
 
     def generate(
         self,
@@ -87,34 +108,37 @@ class LLM:
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
 
-        NOTE: This class automatically batches the given prompts, considering the
-        memory constraint. For the best performance, put all of your prompts into
-        a single list and pass it to this method.
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
 
         Args:
             prompts: A list of prompts to generate completions for.
-            sampling_params: The sampling parameters for text generation. If None,
-                we use the default sampling parameters.
+            sampling_params: The sampling parameters for text generation. If
+                None, we use the default sampling parameters.
             prompt_token_ids: A list of token IDs for the prompts. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             use_tqdm: Whether to use tqdm to display the progress bar.
 
         Returns:
-            A list of `RequestOutput` objects containing the generated completions
-            in the same order as the input prompts.
+            A list of `RequestOutput` objects containing the generated
+            completions in the same order as the input prompts.
         """
         if prompts is None and prompt_token_ids is None:
             raise ValueError("Either prompts or prompt_token_ids must be "
                              "provided.")
         if isinstance(prompts, str):
+            # Convert a single prompt to a list.
             prompts = [prompts]
         if prompts is not None and prompt_token_ids is not None:
             if len(prompts) != len(prompt_token_ids):
                 raise ValueError("The lengths of prompts and prompt_token_ids "
                                  "must be the same.")
         if sampling_params is None:
+            # Use default sampling params.
             sampling_params = SamplingParams()
 
+        # Add requests to the engine.
         if prompts is not None:
             num_requests = len(prompts)
         else:
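
The hunk above tidies `generate()`. A short usage sketch follows; the `SamplingParams` import path and the `RequestOutput` field layout are assumptions, since the diff only shows those names in use:

```python
from aphrodite.common.sampling_params import SamplingParams  # assumed path

prompts = [
    "Explain the KV cache in one sentence.",
    "Write a haiku about GPUs.",
]
params = SamplingParams(temperature=0.8, max_tokens=64)

# A single call batches all prompts; results come back in input order.
outputs = llm.generate(prompts, sampling_params=params)
for out in outputs:
    # Assumed layout: each RequestOutput holds one or more completions.
    print(out.outputs[0].text)
```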
@@ -135,25 +159,27 @@ class LLM:
         prompt_token_ids: Optional[List[int]],
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.aphrodite_engine.add_request(request_id, prompt, sampling_params, prompt_token_ids)
+        self.llm_engine.add_request(request_id, prompt, sampling_params,
+                                    prompt_token_ids)
 
     def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
+        # Initialize tqdm.
         if use_tqdm:
-            num_requests = self.aphrodite_engine.get_num_unfinished_requests()
+            num_requests = self.llm_engine.get_num_unfinished_requests()
             pbar = tqdm(total=num_requests, desc="Processed prompts")
+        # Run the engine.
         outputs: List[RequestOutput] = []
-        while self.aphrodite_engine.has_unfinished_requests():
-            step_outputs = self.aphrodite_engine.step()
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished:
                     outputs.append(output)
                     if use_tqdm:
                         pbar.update(1)
-
         if use_tqdm:
             pbar.close()
-            # Sort the outputs by request ID. Necessary because some outputs
-            # may be finished earlier than previous requests.
+        # Sort the outputs by request ID.
+        # This is necessary because some requests may finish earlier than
+        # requests that were submitted before them.
         outputs = sorted(outputs, key=lambda x: int(x.request_id))
         return outputs
-
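
The final hunk sorts finished outputs by request ID because `step()` can complete requests out of submission order (e.g. short generations finish first). A toy illustration of the invariant the sort restores:

```python
# Completion order is not submission order.
finished_ids = ["2", "0", "1"]           # request ids in completion order
ordered = sorted(finished_ids, key=int)  # what _run_engine does to outputs
assert ordered == ["0", "1", "2"]        # matches the order requests were added
```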