guard against using block manager v1

AlpinDale · 4 months ago
Parent commit: 22425a03f8
2 changed files, 18 additions and 5 deletions
  1. aphrodite/common/config.py (+12 −1)
  2. aphrodite/engine/args_tools.py (+6 −4)

aphrodite/common/config.py (+12 −1)

@@ -1560,6 +1560,7 @@ class CFGConfig:
     def maybe_create_spec_config(
         target_model_config: ModelConfig,
         target_parallel_config: ParallelConfig,
+        target_scheduler_config: SchedulerConfig,
         guidance_model: Optional[str],
     ):
         if guidance_model is None:
@@ -1567,25 +1568,35 @@ class CFGConfig:
 
         guidance_parallel_config = target_parallel_config
         assert target_model_config.model == guidance_model
+        guidance_scheduler_config = target_scheduler_config
         guidance_model_config = target_model_config
 
         return CFGConfig(
             guidance_model_config,
-            guidance_parallel_config
+            guidance_parallel_config,
+            guidance_scheduler_config,
         )
 
     def __init__(
         self,
         guidance_model_config: ModelConfig,
         guidance_parallel_config: ParallelConfig,
+        guidance_scheduler_config: SchedulerConfig,
     ):
         self.guidance_model_config = guidance_model_config
         self.guidance_parallel_config = guidance_parallel_config
+        self.guidance_scheduler_config = guidance_scheduler_config
+        self._verify_args()
 
     def _verify_args(self) -> None:
         if self.guidance_model_config:
             self.guidance_model_config.verify_with_parallel_config(
                 self.guidance_parallel_config)
+        if not self.guidance_scheduler_config.use_v2_block_manager:
+            raise ValueError(
+                "CFG requires usage of the V2 "
+                "block manager. Enable it with --use-v2-block-manager "
+                "or use_v2_block_manager=True.")
 
     def __repr__(self) -> str:
         guidance_model = self.guidance_model_config.model
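The guard itself is self-contained: if the scheduler was built with the v1 block manager, CFGConfig construction now fails immediately instead of misbehaving later. A minimal sketch of that behavior, using a hypothetical dataclass stand-in for the real SchedulerConfig (only the use_v2_block_manager flag is modeled):

    from dataclasses import dataclass

    @dataclass
    class SchedulerConfig:
        # Stand-in for aphrodite's SchedulerConfig; v1 is the default here.
        use_v2_block_manager: bool = False

    def verify_block_manager(scheduler_config: SchedulerConfig) -> None:
        # Mirrors the check added to CFGConfig._verify_args above.
        if not scheduler_config.use_v2_block_manager:
            raise ValueError(
                "CFG requires usage of the V2 block manager. Enable it with "
                "--use-v2-block-manager or use_v2_block_manager=True.")

    verify_block_manager(SchedulerConfig(use_v2_block_manager=True))  # passes
    try:
        verify_block_manager(SchedulerConfig())  # v1 default -> rejected
    except ValueError as e:
        print(e)

Because __init__ now calls self._verify_args(), the check runs at construction time, so a misconfigured engine fails fast during startup rather than at scheduling time.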

aphrodite/engine/args_tools.py (+6 −4)

@@ -1043,10 +1043,6 @@ class EngineArgs:
             if speculative_config is None \
             else speculative_config.num_lookahead_slots
 
-        cfg_config = CFGConfig.maybe_create_spec_config(
-            target_model_config=model_config,
-            target_parallel_config=parallel_config,
-            guidance_model=self.cfg_model)
 
         scheduler_config = SchedulerConfig(
             max_num_batched_tokens=self.max_num_batched_tokens,
@@ -1064,6 +1060,12 @@ class EngineArgs:
                              parallel_config.use_ray),
         )
 
+        cfg_config = CFGConfig.maybe_create_spec_config(
+            target_model_config=model_config,
+            target_parallel_config=parallel_config,
+            target_scheduler_config=scheduler_config,
+            guidance_model=self.cfg_model)
+
         if not HAS_TRITON and self.enable_lora:
             raise ValueError("Triton is not installed, LoRA will not work.")