import argparse, os, sys, glob, math, time
import torch
import numpy as np
from omegaconf import OmegaConf
import streamlit as st
from streamlit import caching
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
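
# NOTE: `from main import ...` only resolves with the repository root on
# sys.path; the __main__ block below appends os.getcwd() for exactly this
# reason, so launch the script from the repository root.
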
rescale = lambda x: (x + 1.) / 2.


def bchw_to_st(x):
    return rescale(x.detach().cpu().numpy().transpose(0, 2, 3, 1))


def save_img(xstart, fname):
    I = (xstart.clip(0, 1)[0] * 255).astype(np.uint8)
    Image.fromarray(I).save(fname)


def get_interactive_image(resize=False):
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is not None:
        image = Image.open(image)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        print("upload image shape: {}".format(image.shape))
        img = Image.fromarray(image)
        if resize:
            img = img.resize((256, 256))
        image = np.array(img)
        return image


def single_image_to_torch(x, permute=True):
    assert x is not None, "Please provide an image through the upload function"
    x = np.array(x)
    x = torch.FloatTensor(x / 255. * 2. - 1.)[None, ...]
    if permute:
        x = x.permute(0, 3, 1, 2)
    return x


def pad_to_M(x, M):
    # pad height and width up to the next multiple of M
    hp = math.ceil(x.shape[2] / M) * M - x.shape[2]
    wp = math.ceil(x.shape[3] / M) * M - x.shape[3]
    x = torch.nn.functional.pad(x, (0, wp, 0, hp, 0, 0, 0, 0))
    return x
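
# Interactive demo: encode the selected example and its conditioning into
# discrete VQ code grids, then autoregressively (re)sample the image codes
# with the transformer, streaming intermediate decodings to the Streamlit UI.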
@torch.no_grad()
def run_conditional(model, dsets):
    if len(dsets.datasets) > 1:
        split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
        dset = dsets.datasets[split]
    else:
        dset = next(iter(dsets.datasets.values()))
    batch_size = 1
    start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)),
                                          value=0,
                                          min_value=0,
                                          max_value=len(dset) - batch_size)
    indices = list(range(start_index, start_index + batch_size))

    example = default_collate([dset[i] for i in indices])

    x = model.get_input("image", example).to(model.device)

    cond_key = model.cond_stage_key
    c = model.get_input(cond_key, example).to(model.device)

    scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0,
                                     step=0.25, value=1.00)
    if scale_factor != 1.0:
        x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
        c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
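
    # encode_to_z / encode_to_c quantize image and conditioning into grids of
    # codebook indices; cshape records the latent grid shape (B, C, H', W').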
    quant_z, z_indices = model.encode_to_z(x)
    quant_c, c_indices = model.encode_to_c(c)

    cshape = quant_z.shape

    xrec = model.first_stage_model.decode(quant_z)
    st.write("image: {}".format(x.shape))
    st.image(bchw_to_st(x), clamp=True, output_format="PNG")
    st.write("image reconstruction: {}".format(xrec.shape))
    st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")

    if cond_key == "segmentation":
        # get an RGB visualization from the one-hot segmentation mask
        num_classes = c.shape[1]
        c = torch.argmax(c, dim=1, keepdim=True)
        c = torch.nn.functional.one_hot(c, num_classes=num_classes)
        c = c.squeeze(1).permute(0, 3, 1, 2).float()
        c = model.cond_stage_model.to_rgb(c)

    st.write(f"{cond_key}: {tuple(c.shape)}")
    st.image(bchw_to_st(c), clamp=True, output_format="PNG")
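
    # z_indices seeds the sampling: positions to be generated are zeroed
    # below, while kept positions (image completion) stay fixed.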
    idx = z_indices

    half_sample = st.sidebar.checkbox("Image Completion", value=False)
    if half_sample:
        # keep the first half of the codes and complete the rest
        start = idx.shape[1] // 2
    else:
        start = 0

    idx[:, start:] = 0
    idx = idx.reshape(cshape[0], cshape[2], cshape[3])
    start_i = start // cshape[3]
    start_j = start % cshape[3]

    if not half_sample and quant_z.shape == quant_c.shape:
        st.info("Setting idx to c_indices")
        idx = c_indices.clone().reshape(cshape[0], cshape[2], cshape[3])

    cidx = c_indices
    cidx = cidx.reshape(quant_c.shape[0], quant_c.shape[2], quant_c.shape[3])

    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
    st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
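
    # sampling hyperparameters: temperature rescales the logits, top-k
    # truncates the distribution, and "Sample" toggles between multinomial
    # sampling and greedy (argmax) decoding.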
    temperature = st.number_input("Temperature", value=1.0)
    top_k = st.number_input("Top k", value=100)
    sample = st.checkbox("Sample", value=True)
    update_every = st.number_input("Update every", value=75)

    st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")

    animate = st.checkbox("animate")
    if animate:
        import imageio
        outvid = "sampling.mp4"
        writer = imageio.get_writer(outvid, fps=25)
    elapsed_t = st.empty()
    info = st.empty()
    st.text("Sampled")
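
    # Sliding-window sampling: the transformer context covers a 16x16 patch
    # of codes (256 tokens), so larger code grids are generated by sliding a
    # 16x16 window over the grid. (local_i, local_j) is the position of the
    # current target inside the window; it is clamped near the grid borders
    # and otherwise kept at the window center.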
    if st.button("Sample"):
        output = st.empty()
        start_t = time.time()
        for i in range(start_i, cshape[2]):
            if i <= 8:
                local_i = i
            elif cshape[2] - i < 8:
                local_i = 16 - (cshape[2] - i)
            else:
                local_i = 8
            for j in range(start_j, cshape[3]):
                if j <= 8:
                    local_j = j
                elif cshape[3] - j < 8:
                    local_j = 16 - (cshape[3] - j)
                else:
                    local_j = 8

                i_start = i - local_i
                i_end = i_start + 16
                j_start = j - local_j
                j_end = j_start + 16
                elapsed_t.text(f"Time: {time.time() - start_t} seconds")
                info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
                # flatten the current 16x16 crop of image codes and prepend
                # the corresponding conditioning codes as the prefix
                patch = idx[:, i_start:i_end, j_start:j_end]
                patch = patch.reshape(patch.shape[0], -1)
                cpatch = cidx[:, i_start:i_end, j_start:j_end]
                cpatch = cpatch.reshape(cpatch.shape[0], -1)
                patch = torch.cat((cpatch, patch), dim=1)
                logits, _ = model.transformer(patch[:, :-1])
                # keep the predictions for the 16*16=256 image-code positions
                logits = logits[:, -256:, :]
                logits = logits.reshape(cshape[0], 16, 16, -1)
                logits = logits[:, local_i, local_j, :]
                logits = logits / temperature

                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                idx[:, i, j] = ix

                if (i * cshape[3] + j) % update_every == 0:
                    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
                    xstart = bchw_to_st(xstart)
                    output.image(xstart, clamp=True, output_format="PNG")
                    if animate:
                        writer.append_data((xstart[0] * 255).clip(0, 255).astype(np.uint8))
        xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
        xstart = bchw_to_st(xstart)
        output.image(xstart, clamp=True, output_format="PNG")
        #save_img(xstart, "full_res_sample.png")
        if animate:
            writer.close()
            st.video(outvid)
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left to right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to a single config. If specified, base configs are ignored "
        "(except for the last one, if the flag is given without a value).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--ignore_base_data",
        action="store_true",
        help="Ignore the data specification from base configs. Useful if you want "
        "to specify a custom dataset on the command line.",
    )
    return parser


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    if "ckpt_path" in config.params:
        st.warning("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5
    try:
        if "ckpt_path" in config.params.first_stage_config.params:
            config.params.first_stage_config.params.ckpt_path = None
            st.warning("Deleting the first-stage restore-ckpt path from the config...")
        if "ckpt_path" in config.params.cond_stage_config.params:
            config.params.cond_stage_config.params.ckpt_path = None
            st.warning("Deleting the cond-stage restore-ckpt path from the config...")
    except Exception:
        pass

    model = instantiate_from_config(config)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        st.info(f"Missing Keys in State Dict: {missing}")
        st.info(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def get_data(config):
    # get data
    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    return data
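
# Cache model and data across Streamlit reruns; allow_output_mutation=True
# skips Streamlit's hash check on the returned (mutable) torch objects.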
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    # get data
    dsets = get_data(config)  # calls data.config ...

    # now load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"],
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return dsets, model, global_step
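
# Typical invocation (script arguments go after Streamlit's `--` separator;
# the logdir path is an example placeholder):
#
#   streamlit run sample_conditional.py -- -r <path/to/logdir>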
if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir: {logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if type(opt.config) == str:
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)

    st.sidebar.text(ckpt)
    gs = st.sidebar.empty()
    gs.text("Global step: ?")
    st.sidebar.text("Options")
    #gpu = st.sidebar.checkbox("GPU", value=True)
    gpu = True
    #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
    eval_mode = True
    #show_config = st.sidebar.checkbox("Show Config", value=False)
    show_config = False
    if show_config:
        st.info("Checkpoint: {}".format(ckpt))
        st.json(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    gs.text(f"Global step: {global_step}")

    run_conditional(model, dsets)