import argparse, os, sys, glob, math, time
import torch
import numpy as np
from omegaconf import OmegaConf
import streamlit as st
from streamlit import caching
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
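

# Streamlit demo for conditional sampling: encode an example image and its
# conditioning into discrete codes, autoregressively sample new code indices
# with the transformer, and decode the result back into an image.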

rescale = lambda x: (x + 1.) / 2.


def bchw_to_st(x):
    return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))


def save_img(xstart, fname):
    I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
    Image.fromarray(I).save(fname)


def get_interactive_image(resize=False):
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is not None:
        image = Image.open(image)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        print("upload image shape: {}".format(image.shape))
        img = Image.fromarray(image)
        if resize:
            img = img.resize((256, 256))
        image = np.array(img)
        return image


def single_image_to_torch(x, permute=True):
    assert x is not None, "Please provide an image through the upload function"
    x = np.array(x)
    x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
    if permute:
        x = x.permute(0, 3, 1, 2)
    return x


def pad_to_M(x, M):
    hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
    wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
    x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
    return x
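

# Pick a dataset example, show the first-stage reconstruction and the
# conditioning, then let the transformer resample the code map.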


@torch.no_grad()
def run_conditional(model, dsets):
    if len(dsets.datasets) > 1:
        split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
        dset = dsets.datasets[split]
    else:
        dset = next(iter(dsets.datasets.values()))
    batch_size = 1
    start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
                                          min_value=0,
                                          max_value=len(dset)-batch_size)
    indices = list(range(start_index, start_index+batch_size))

    example = default_collate([dset[i] for i in indices])

    x = model.get_input("image", example).to(model.device)
    cond_key = model.cond_stage_key
    c = model.get_input(cond_key, example).to(model.device)

    scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
    if scale_factor != 1.0:
        x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
        c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")

    quant_z, z_indices = model.encode_to_z(x)
    quant_c, c_indices = model.encode_to_c(c)
    cshape = quant_z.shape

    xrec = model.first_stage_model.decode(quant_z)
    st.write("image: {}".format(x.shape))
    st.image(bchw_to_st(x), clamp=True, output_format="PNG")
    st.write("image reconstruction: {}".format(xrec.shape))
    st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")

    if cond_key == "segmentation":
        # get image from segmentation mask
        num_classes = c.shape[1]
        c = torch.argmax(c, dim=1, keepdim=True)
        c = torch.nn.functional.one_hot(c, num_classes=num_classes)
        c = c.squeeze(1).permute(0, 3, 1, 2).float()
        c = model.cond_stage_model.to_rgb(c)

    st.write(f"{cond_key}: {tuple(c.shape)}")
    st.image(bchw_to_st(c), clamp=True, output_format="PNG")

    idx = z_indices

    half_sample = st.sidebar.checkbox("Image Completion", value=False)
    if half_sample:
        start = idx.shape[1]//2
    else:
        start = 0

    idx[:,start:] = 0
    idx = idx.reshape(cshape[0],cshape[2],cshape[3])
    start_i = start//cshape[3]
    start_j = start % cshape[3]

    if not half_sample and quant_z.shape == quant_c.shape:
        st.info("Setting idx to c_indices")
        idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])

    cidx = c_indices
    cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])

    xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
    st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")

    temperature = st.number_input("Temperature", value=1.0)
    top_k = st.number_input("Top k", value=100)
    sample = st.checkbox("Sample", value=True)
    update_every = st.number_input("Update every", value=75)

    st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")

    animate = st.checkbox("animate")
    if animate:
        import imageio
        outvid = "sampling.mp4"
        writer = imageio.get_writer(outvid, fps=25)
    elapsed_t = st.empty()
    info = st.empty()
    st.text("Sampled")
    if st.button("Sample"):
        output = st.empty()
        start_t = time.time()
        for i in range(start_i, cshape[2]):
            if i <= 8:
                local_i = i
            elif cshape[2]-i < 8:
                local_i = 16-(cshape[2]-i)
            else:
                local_i = 8
            for j in range(start_j, cshape[3]):
                if j <= 8:
                    local_j = j
                elif cshape[3]-j < 8:
                    local_j = 16-(cshape[3]-j)
                else:
                    local_j = 8
                i_start = i-local_i
                i_end = i_start+16
                j_start = j-local_j
                j_end = j_start+16
                elapsed_t.text(f"Time: {time.time() - start_t} seconds")
                info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
                patch = idx[:,i_start:i_end,j_start:j_end]
                patch = patch.reshape(patch.shape[0],-1)
                cpatch = cidx[:, i_start:i_end, j_start:j_end]
                cpatch = cpatch.reshape(cpatch.shape[0], -1)
                patch = torch.cat((cpatch, patch), dim=1)
                logits,_ = model.transformer(patch[:,:-1])
                logits = logits[:, -256:, :]
                logits = logits.reshape(cshape[0],16,16,-1)
                logits = logits[:,local_i,local_j,:]
                logits = logits/temperature

                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                idx[:,i,j] = ix

                if (i*cshape[3]+j) % update_every == 0:
                    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
                    xstart = bchw_to_st(xstart)
                    output.image(xstart, clamp=True, output_format="PNG")
                    if animate:
                        writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))

        xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
        xstart = bchw_to_st(xstart)
        output.image(xstart, clamp=True, output_format="PNG")
        #save_img(xstart, "full_res_sample.png")
        if animate:
            writer.close()
            st.video(outvid)


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to a single config. If specified, base configs are ignored "
             "(except for the last one if left unspecified).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--ignore_base_data",
        action="store_true",
        help="Ignore data specification from base configs. Useful if you want "
             "to specify custom datasets on the command line.",
    )
    return parser
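

# Clear any restore-checkpoint paths before instantiation; weights are loaded
# from the provided state dict instead of from the paths in the config.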
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    if "ckpt_path" in config.params:
        st.warning("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5
    try:
        if "ckpt_path" in config.params.first_stage_config.params:
            config.params.first_stage_config.params.ckpt_path = None
            st.warning("Deleting the first-stage restore-ckpt path from the config...")
        if "ckpt_path" in config.params.cond_stage_config.params:
            config.params.cond_stage_config.params.ckpt_path = None
            st.warning("Deleting the cond-stage restore-ckpt path from the config...")
    except Exception:
        pass

    model = instantiate_from_config(config)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        st.info(f"Missing Keys in State Dict: {missing}")
        st.info(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def get_data(config):
    # get data
    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    return data
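

# Cache the loaded datasets and model across Streamlit reruns so that UI
# interactions do not trigger a reload of the checkpoint.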
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    # get data
    dsets = get_data(config)  # calls data.config ...

    # now load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"],
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return dsets, model, global_step
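

# Entry point: resolve logdir and checkpoint from --resume, merge the configs
# with command-line overrides, then launch the demo.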
if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir: {logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs+opt.base

    if opt.config:
        if type(opt.config) == str:
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)

    st.sidebar.text(ckpt)
    gs = st.sidebar.empty()
    gs.text("Global step: ?")
    st.sidebar.text("Options")
    #gpu = st.sidebar.checkbox("GPU", value=True)
    gpu = True
    #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
    eval_mode = True
    #show_config = st.sidebar.checkbox("Show Config", value=False)
    show_config = False
    if show_config:
        st.info("Checkpoint: {}".format(ckpt))
        st.json(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    gs.text(f"Global step: {global_step}")

    run_conditional(model, dsets)