AlpinDale пре 1 година
родитељ
комит
4cdf165ee9
1 измењених фајлова са 20 додато и 21 уклоњено
  1. 20 21
      aphrodite/engine/args_tools.py

+ 20 - 21
aphrodite/engine/args_tools.py

@@ -3,14 +3,12 @@ import dataclasses
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
-from torch import quantization
+from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig)
 
 
 @dataclass
 class EngineArgs:
-    """Arguments for Aphrodite engine."""
+    """Arguments for the Aphrodite engine."""
     model: str
     tokenizer: Optional[str] = None
     tokenizer_mode: str = 'auto'
@@ -30,7 +28,7 @@ class EngineArgs:
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
-    quantization = Optional[str] = None
+    quantization: Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -40,25 +38,25 @@ class EngineArgs:
     @staticmethod
     def add_cli_args(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
-        """Shared CLI arguments for Aphrodite engine."""
+        """Shared CLI arguments for the Aphrodite engine."""
         # Model arguments
         parser.add_argument(
             '--model',
             type=str,
-            default='PygmalionAI/pygmalion-2-7b',
+            default='facebook/opt-125m',
             help='name or path of the huggingface model to use')
         parser.add_argument(
             '--tokenizer',
             type=str,
             default=EngineArgs.tokenizer,
             help='name or path of the huggingface tokenizer to use')
-        parser.add_argument('--revision',
-                            type=str,
-                            default=None,
-                            help='the specific model version to use.'
-                            'It can be a branch name, a tag name, '
-                            'or a commit ID. If unspecified, will '
-                            'use the default version.')
+        parser.add_argument(
+            '--revision',
+            type=str,
+            default=None,
+            help='the specific model version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -90,7 +88,6 @@ class EngineArgs:
             'a numpy cache to speed up the loading. '
             '"dummy" will initialize the weights with random values, '
             'which is mainly for profiling.')
-        # TODO: Support FP32.
         parser.add_argument(
             '--dtype',
             type=str,
@@ -152,10 +149,13 @@ class EngineArgs:
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
-        parser.add_argument('--quantization', '-q',
-                            type=str, choices=["awq", None],
+        # Quantization settings.
+        parser.add_argument('--quantization',
+                            '-q',
+                            type=str,
+                            choices=['awq', None],
                             default=None,
-                            help="Method used for quantization.")
+                            help='Method used to quantize the weights')
         return parser
 
     @classmethod
@@ -169,7 +169,6 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-        # Initialize the configs.
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
@@ -189,7 +188,7 @@ class EngineArgs:
 
 @dataclass
 class AsyncEngineArgs(EngineArgs):
-    """Arguments for asynchronous Aphrodite engine."""
+    """Arguments for the asynchronous Aphrodite engine."""
     engine_use_ray: bool = False
     disable_log_requests: bool = False
     max_log_len: Optional[int] = None
@@ -209,6 +208,6 @@ class AsyncEngineArgs(EngineArgs):
                             type=int,
                             default=None,
                             help='max number of prompt characters or prompt '
-                            'ID numbers being printed in the long. '
+                            'ID numbers being printed in log. '
                             'Default: unlimited.')
         return parser