|
@@ -1560,6 +1560,7 @@ class CFGConfig:
|
|
|
def maybe_create_spec_config(
|
|
|
target_model_config: ModelConfig,
|
|
|
target_parallel_config: ParallelConfig,
|
|
|
+ target_scheduler_config: SchedulerConfig,
|
|
|
guidance_model: Optional[str],
|
|
|
):
|
|
|
if guidance_model is None:
|
|
@@ -1567,25 +1568,35 @@ class CFGConfig:
|
|
|
|
|
|
guidance_parallel_config = target_parallel_config
|
|
|
assert target_model_config.model == guidance_model
|
|
|
+ guidance_scheduler_config = target_scheduler_config
|
|
|
guidance_model_config = target_model_config
|
|
|
|
|
|
return CFGConfig(
|
|
|
guidance_model_config,
|
|
|
- guidance_parallel_config
|
|
|
+ guidance_parallel_config,
|
|
|
+ guidance_scheduler_config,
|
|
|
)
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
guidance_model_config: ModelConfig,
|
|
|
guidance_parallel_config: ParallelConfig,
|
|
|
+ guidance_scheduler_config: SchedulerConfig,
|
|
|
):
|
|
|
self.guidance_model_config = guidance_model_config
|
|
|
self.guidance_parallel_config = guidance_parallel_config
|
|
|
+ self.guidance_scheduler_config = guidance_scheduler_config
|
|
|
+ self._verify_args()
|
|
|
|
|
|
def _verify_args(self) -> None:
|
|
|
if self.guidance_model_config:
|
|
|
self.guidance_model_config.verify_with_parallel_config(
|
|
|
self.guidance_parallel_config)
|
|
|
+ if not self.guidance_scheduler_config.use_v2_block_manager:
|
|
|
+ raise ValueError(
|
|
|
+ "CFG requires usage of the V2 "
|
|
|
+ "block manager. Enable it with --use-v2-block-manager "
|
|
|
+ "or use_v2_block_manager=True.")
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
guidance_model = self.guidance_model_config.model
|