fix: use named args

AlpinDale 7 months ago
parent
commit ec5b99d075

+ 46 - 46
aphrodite/engine/args_tools.py

@@ -676,56 +676,56 @@ class EngineArgs:
                 "BitsAndBytes load format and QLoRA adapter only support "
                 f"'bitsandbytes' quantization, but got {self.quantization}")
 
-        device_config = DeviceConfig(self.device)
+        device_config = DeviceConfig(device=self.device)
+
         model_config = ModelConfig(
-            self.model,
-            self.tokenizer,
-            self.tokenizer_mode,
-            self.trust_remote_code,
-            self.dtype,
-            self.seed,
-            self.revision,
-            self.code_revision,
-            self.rope_scaling,
-            self.tokenizer_revision,
-            self.max_model_len,
-            self.quantization,
-            self.load_in_4bit,
-            self.load_in_8bit,
-            self.load_in_smooth,
-            self.deepspeed_fp_bits,
-            self.quantization_param_path,
-            self.enforce_eager,
-            self.max_context_len_to_capture,
-            self.max_seq_len_to_capture,
-            self.max_logprobs,
-            self.disable_sliding_window,
-            self.skip_tokenizer_init,
+            model=self.model,
+            tokenizer=self.tokenizer,
+            tokenizer_mode=self.tokenizer_mode,
+            trust_remote_code=self.trust_remote_code,
+            dtype=self.dtype,
+            seed=self.seed,
+            revision=self.revision,
+            code_revision=self.code_revision,
+            rope_scaling=self.rope_scaling,
+            tokenizer_revision=self.tokenizer_revision,
+            max_model_len=self.max_model_len,
+            quantization=self.quantization,
+            load_in_4bit=self.load_in_4bit,
+            load_in_8bit=self.load_in_8bit,
+            load_in_smooth=self.load_in_smooth,
+            deepspeed_fp_bits=self.deepspeed_fp_bits,
+            quantization_param_path=self.quantization_param_path,
+            enforce_eager=self.enforce_eager,
+            max_context_len_to_capture=self.max_context_len_to_capture,
+            max_seq_len_to_capture=self.max_seq_len_to_capture,
+            max_logprobs=self.max_logprobs,
+            disable_sliding_window=self.disable_sliding_window,
+            skip_tokenizer_init=self.skip_tokenizer_init,
         )
 
         cache_config = CacheConfig(
-            self.block_size,
-            self.gpu_memory_utilization,
-            self.swap_space,
-            self.kv_cache_dtype,
-            # self.kv_quant_params_path,
-            self.num_gpu_blocks_override,
-            model_config.get_sliding_window(),
-            self.enable_prefix_caching,
+            block_size=self.block_size,
+            gpu_memory_utilization=self.gpu_memory_utilization,
+            swap_space=self.swap_space,
+            cache_dtype=self.kv_cache_dtype,
+            num_gpu_blocks_override=self.num_gpu_blocks_override,
+            sliding_window=model_config.get_sliding_window(),
+            enable_prefix_caching=self.enable_prefix_caching,
         )
 
         parallel_config = ParallelConfig(
-            self.pipeline_parallel_size,
-            self.tensor_parallel_size,
-            self.worker_use_ray,
-            self.max_parallel_loading_workers,
-            self.disable_custom_all_reduce,
-            TokenizerPoolConfig.create_config(
-                self.tokenizer_pool_size,
-                self.tokenizer_pool_type,
-                self.tokenizer_pool_extra_config,
+            pipeline_parallel_size=self.pipeline_parallel_size,
+            tensor_parallel_size=self.tensor_parallel_size,
+            worker_use_ray=self.worker_use_ray,
+            max_parallel_loading_workers=self.max_parallel_loading_workers,
+            disable_custom_all_reduce=self.disable_custom_all_reduce,
+            tokenizer_pool_config=TokenizerPoolConfig.create_config(
+                tokenizer_pool_size=self.tokenizer_pool_size,
+                tokenizer_pool_type=self.tokenizer_pool_type,
+                tokenizer_pool_extra_config=self.tokenizer_pool_extra_config,
             ),
-            self.ray_workers_use_nsight,
+            ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend)
 
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
@@ -744,10 +744,10 @@ class EngineArgs:
         )
 
         scheduler_config = SchedulerConfig(
-            self.max_num_batched_tokens,
-            self.max_num_seqs,
-            model_config.max_model_len,
-            self.use_v2_block_manager,
+            max_num_batched_tokens=self.max_num_batched_tokens,
+            max_num_seqs=self.max_num_seqs,
+            max_model_len=model_config.max_model_len,
+            use_v2_block_manager=self.use_v2_block_manager,
             num_lookahead_slots=(self.num_lookahead_slots
                                  if speculative_config is None else
                                  speculative_config.num_lookahead_slots),

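The point of the change above is that these config constructors take long parameter lists, and positional calls bind every value purely by order. Below is a minimal sketch of the failure mode that keyword arguments guard against, using a hypothetical `DemoConfig` dataclass rather than Aphrodite's actual `ModelConfig`/`CacheConfig` signatures:

```python
from dataclasses import dataclass


@dataclass
class DemoConfig:
    """Hypothetical stand-in for a config class; not Aphrodite's real code."""
    model: str
    tokenizer: str
    trust_remote_code: bool = False
    seed: int = 0


# Positional call: values are bound strictly by position. If a new parameter
# (say `tokenizer_mode`) is later inserted after `tokenizer`, the boolean and
# the seed below silently shift into the wrong slots.
risky = DemoConfig("facebook/opt-125m", "facebook/opt-125m", True, 42)

# Keyword call: values are bound by name, so inserting or reordering
# parameters either keeps the binding correct or raises a TypeError,
# instead of silently misconfiguring the engine.
safe = DemoConfig(
    model="facebook/opt-125m",
    tokenizer="facebook/opt-125m",
    trust_remote_code=True,
    seed=42,
)

assert risky == safe  # identical today; only `safe` survives a signature change
```

The diff applies the same idea to `ModelConfig`, `CacheConfig`, `ParallelConfig`, `TokenizerPoolConfig.create_config`, and `SchedulerConfig`, so future additions to those signatures cannot silently shift the arguments passed from `EngineArgs`.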
+ 2 - 1
aphrodite/spec_decode/proposer_worker_base.py

@@ -31,6 +31,7 @@ class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
     def list_loras(self) -> Set[int]:
         raise ValueError(f"{type(self)} does not support LoRA")
 
+
 class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
     """Proposer worker which does not use a model with kvcache"""
 
@@ -51,7 +52,7 @@ class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
 
     def get_cache_block_size_bytes(self) -> int:
         return 0
-    
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise ValueError(f"{type(self)} does not support LoRA")