# Conditional sampling script: writes originals, reconstructions and samples for each dataset example into --outdir.
- import argparse, os, sys, glob, math, time
- import torch
- import numpy as np
- from omegaconf import OmegaConf
- from PIL import Image
- from main import instantiate_from_config, DataModuleFromConfig
- from torch.utils.data import DataLoader
- from torch.utils.data.dataloader import default_collate
- from tqdm import trange
def save_image(x, path):
    """Write one image tensor to ``path`` as an 8-bit image file.

    ``x`` must be a 3-dimensional tensor whose first axis has size 3
    (channels first). Values are mapped with ``(x + 1) * 127.5`` and
    clipped to ``[0, 255]`` before being converted to ``uint8``.
    """
    channels, _height, _width = x.shape
    assert channels == 3
    # Move to host memory and switch to channels-last, which is the layout
    # ``Image.fromarray`` is given here.
    pixels = x.detach().cpu().numpy()
    pixels = pixels.transpose(1, 2, 0)
    pixels = (pixels + 1.0) * 127.5
    pixels = pixels.clip(0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path)
def _window_offset(pos, extent, window=16):
    """Return the offset of ``pos`` inside the sliding window covering it.

    The window spans ``window`` cells and is kept inside ``[0, extent)``:
    it is pinned to the start while ``pos <= window // 2``, pinned to the
    end while fewer than ``window // 2`` cells remain after ``pos``, and
    otherwise placed so that ``pos`` sits at offset ``window // 2``.
    """
    half = window // 2
    if pos <= half:
        return pos
    if extent - pos < half:
        return window - (extent - pos)
    return half


@torch.no_grad()
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
    """Sample images conditioned on dataset examples and save them to disk.

    For every batch of examples this writes three PNG files per example,
    named after the example's dataset index:

    * ``<outdir>/originals``       - the input image,
    * ``<outdir>/reconstructions`` - the first-stage decode of its codes,
    * ``<outdir>/samples``         - an image whose codes were all re-sampled
      from the transformer, one grid position at a time, using a 16x16
      sliding window over the code grid.

    Args:
        model: object providing ``get_input``, ``encode_to_z``,
            ``encode_to_c``, ``first_stage_model.decode``, ``transformer``,
            ``top_k_logits``, ``decode_to_img``, ``cond_stage_key`` and
            ``device`` (project type; not defined in this file).
        dsets: object with a ``datasets`` mapping. If it holds more than one
            dataset, the one whose key sorts first is used.
        outdir: directory containing the three output sub-directories
            (they must already exist).
        top_k: if not ``None``, passed to ``model.top_k_logits`` to restrict
            sampling.
        temperature: logits are divided by this value before the softmax.
        batch_size: number of dataset examples processed together.

    Raises:
        ValueError: if the code grid is smaller than the 16x16 window.
    """
    if len(dsets.datasets) > 1:
        split = sorted(dsets.datasets.keys())[0]
        dset = dsets.datasets[split]
    else:
        dset = next(iter(dsets.datasets.values()))
    print("Dataset: ", dset.__class__.__name__)

    window = 16          # side length of the sliding window over the code grid
    half_sample = False  # True: keep the first half of the codes, sample the rest
    sample = True        # True: draw from the distribution; False: take the argmax

    num_examples = len(dset)
    # Step over the whole dataset; the final batch may be smaller than
    # ``batch_size`` instead of being dropped.
    for start_idx in trange(0, num_examples, batch_size):
        indices = list(range(start_idx, min(start_idx + batch_size, num_examples)))
        example = default_collate([dset[i] for i in indices])

        x = model.get_input("image", example).to(model.device)
        for i in range(x.shape[0]):
            save_image(x[i], os.path.join(outdir, "originals",
                                          "{:06}.png".format(indices[i])))

        cond_key = model.cond_stage_key
        c = model.get_input(cond_key, example).to(model.device)

        quant_z, z_indices = model.encode_to_z(x)
        quant_c, c_indices = model.encode_to_c(c)

        cshape = quant_z.shape
        if cshape[2] < window or cshape[3] < window:
            # A smaller grid would yield truncated patches that can no
            # longer be reshaped into a window x window block of logits.
            raise ValueError(
                "code grid {}x{} is smaller than the {}x{} sampling window".format(
                    cshape[2], cshape[3], window, window))

        xrec = model.first_stage_model.decode(quant_z)
        for i in range(xrec.shape[0]):
            save_image(xrec[i], os.path.join(outdir, "reconstructions",
                                             "{:06}.png".format(indices[i])))

        if cond_key == "segmentation":
            # get image from segmentation mask
            # NOTE(review): the resulting RGB ``c`` is not used or saved
            # anywhere later in this function - confirm whether it should be.
            num_classes = c.shape[1]
            c = torch.argmax(c, dim=1, keepdim=True)
            c = torch.nn.functional.one_hot(c, num_classes=num_classes)
            c = c.squeeze(1).permute(0, 3, 1, 2).float()
            c = model.cond_stage_model.to_rgb(c)

        idx = z_indices
        if half_sample:
            start = idx.shape[1]//2
        else:
            start = 0

        # Blank out every code that is going to be re-sampled (in place).
        idx[:, start:] = 0
        idx = idx.reshape(cshape[0], cshape[2], cshape[3])
        start_i = start//cshape[3]
        start_j = start % cshape[3]

        cidx = c_indices
        cidx = cidx.reshape(quant_c.shape[0], quant_c.shape[2], quant_c.shape[3])

        for i in range(start_i, cshape[2]):
            local_i = _window_offset(i, cshape[2], window)
            i_start = i - local_i
            i_end = i_start + window
            for j in range(start_j, cshape[3]):
                local_j = _window_offset(j, cshape[3], window)
                j_start = j - local_j
                j_end = j_start + window

                patch = idx[:, i_start:i_end, j_start:j_end]
                patch = patch.reshape(patch.shape[0], -1)
                cpatch = cidx[:, i_start:i_end, j_start:j_end]
                cpatch = cpatch.reshape(cpatch.shape[0], -1)
                # Conditioning codes first, then the image codes.
                patch = torch.cat((cpatch, patch), dim=1)
                logits, _ = model.transformer(patch[:, :-1])
                # Keep only the trailing window*window predictions and pick
                # the one for the current position inside the window.
                logits = logits[:, -(window * window):, :]
                logits = logits.reshape(cshape[0], window, window, -1)
                logits = logits[:, local_i, local_j, :]

                logits = logits/temperature

                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                # ``ix`` is (batch, 1) while the target slot is (batch,);
                # drop the trailing axis so this also works for batch > 1.
                idx[:, i, j] = ix.squeeze(-1)

        xsample = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
        for i in range(xsample.shape[0]):
            save_image(xsample[i], os.path.join(outdir, "samples",
                                                "{:06}.png".format(indices[i])))
def get_parser():
    """Build the command-line parser for this script.

    Options: ``-r/--resume``, ``-b/--base``, ``-c/--config``,
    ``--ignore_base_data``, ``--outdir`` (required), ``--top_k`` and
    ``--temperature``.
    """
    # (flags, keyword options) in the order they are registered.
    argument_table = (
        (
            ("-r", "--resume"),
            {
                "type": str,
                "nargs": "?",
                "help": "load from logdir or checkpoint in logdir",
            },
        ),
        (
            ("-b", "--base"),
            {
                "nargs": "*",
                "metavar": "base_config.yaml",
                "help": "paths to base configs. Loaded from left-to-right. "
                        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
                "default": [],
            },
        ),
        (
            ("-c", "--config"),
            {
                "nargs": "?",
                "metavar": "single_config.yaml",
                "help": "path to single config. If specified, base configs will be ignored "
                        "(except for the last one if left unspecified).",
                "const": True,
                "default": "",
            },
        ),
        (
            ("--ignore_base_data",),
            {
                "action": "store_true",
                "help": "Ignore data specification from base configs. Useful if you want "
                        "to specify a custom datasets on the command line.",
            },
        ),
        (
            ("--outdir",),
            {
                "required": True,
                "type": str,
                "help": "Where to write outputs to.",
            },
        ),
        (
            ("--top_k",),
            {
                "type": int,
                "default": 100,
                "help": "Sample from among top-k predictions.",
            },
        ),
        (
            ("--temperature",),
            {
                "type": float,
                "default": 1.0,
                "help": "Sampling temperature.",
            },
        ),
    )

    parser = argparse.ArgumentParser()
    for flags, options in argument_table:
        parser.add_argument(*flags, **options)
    return parser
def _clear_stage_ckpt_path(params, stage_key, label):
    """Best-effort reset of ``params.<stage_key>.params.ckpt_path`` to None.

    Configs that lack the stage entry (or whose entry is not a mapping with
    ``params``) are left untouched.
    """
    try:
        stage_params = getattr(params, stage_key).params
        if "ckpt_path" in stage_params:
            stage_params.ckpt_path = None
            print(f"Deleting the {label} restore-ckpt path from the config...")
    except (AttributeError, KeyError, TypeError):
        # Only lookup failures are tolerated here; anything else (including
        # KeyboardInterrupt) propagates instead of being swallowed.
        pass


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate the model described by ``config`` and load weights into it.

    Before instantiation the config is edited in place: ``ckpt_path`` entries
    (top level, first stage and cond stage) are set to ``None``, and a
    ``downsample_cond_size`` entry is replaced by
    ``downsample_cond_factor = 0.5``.

    Args:
        config: model config with a ``params`` mapping, passed to
            ``instantiate_from_config``.
        sd: state dict to load with ``strict=False``, or ``None`` to keep
            the freshly instantiated weights.
        gpu: if true, call ``model.cuda()``.
        eval_mode: if true, call ``model.eval()``.

    Returns:
        ``{"model": model}``.
    """
    if "ckpt_path" in config.params:
        print("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5

    # Each stage is handled on its own so that a config without a
    # first-stage entry still gets its cond-stage path cleared.
    _clear_stage_ckpt_path(config.params, "first_stage_config", "first-stage")
    _clear_stage_ckpt_path(config.params, "cond_stage_config", "cond-stage")

    model = instantiate_from_config(config)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"Missing Keys in State Dict: {missing}")
        print(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
def get_data(config):
    """Instantiate the data object from ``config.data`` and set it up.

    Calls ``prepare_data()`` and then ``setup()`` on the instantiated object
    before returning it.
    """
    datamodule = instantiate_from_config(config.data)
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    """Load the datasets and the model, optionally restoring a checkpoint.

    Args:
        config: full config with ``data`` and ``model`` sections.
        ckpt: checkpoint path; when falsy the model keeps its freshly
            instantiated weights.
        gpu: forwarded to ``load_model_from_config``.
        eval_mode: forwarded to ``load_model_from_config``.

    Returns:
        ``(dsets, model, global_step)`` where ``global_step`` is read from
        the checkpoint, or ``None`` when no checkpoint was given.
    """
    dsets = get_data(config)

    # NOTE: torch.load unpickles the file; only use checkpoints you trust.
    checkpoint = torch.load(ckpt, map_location="cpu") if ckpt else {"state_dict": None}
    global_step = checkpoint["global_step"] if ckpt else None

    loaded = load_model_from_config(
        config.model,
        checkpoint["state_dict"],
        gpu=gpu,
        eval_mode=eval_mode,
    )
    return dsets, loaded["model"], global_step
if __name__ == "__main__":
    # Script entry point: resolve configs/checkpoint from the command line,
    # load data and model, then write originals, reconstructions and samples.
    sys.path.append(os.getcwd())

    parser = get_parser()
    # Unrecognised arguments are kept and later parsed as config overrides.
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # A checkpoint file was given: derive the log directory from it.
            paths = opt.resume.split("/")
            try:
                # Keep everything up to and including the entry after "logs".
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir:{logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        # Configs found in the log directory come first so that configs
        # passed with -b are merged on top of them.
        opt.base = base_configs+opt.base

    if opt.config:
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            # "-c" was given without a value: fall back to the last base config.
            if not opt.base:
                raise ValueError(
                    "-c/--config was given without a path and no base config is available")
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if "data" in config:
                del config["data"]
    config = OmegaConf.merge(*configs, cli)

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    # global_step is None when no checkpoint was loaded; "{:06}" cannot
    # format None, so use a fixed label in that case.
    if global_step is None:
        step_label = "nockpt"
    else:
        step_label = "{:06}".format(global_step)
    outdir = os.path.join(opt.outdir, "{}_{}_{}".format(step_label,
                                                        opt.top_k,
                                                        opt.temperature))
    os.makedirs(outdir, exist_ok=True)
    print("Writing samples to ", outdir)
    for k in ["originals", "reconstructions", "samples"]:
        os.makedirs(os.path.join(outdir, k), exist_ok=True)
    run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
|