@@ -53,7 +53,7 @@ class CFGWorker(LoraNotSupportedWorkerBase):
assert self.parallel_config.pipeline_parallel_size == 1
def init_device(self):
- self.root_task_handler.init_device()
+ self.root_worker.init_device()
self.guidance_worker.init_device()
def load_model(self):