# sample_fast.py
import argparse
import glob
import os
import sys
import time

import numpy as np
import torch
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange

from main import instantiate_from_config
from taming.modules.transformer.mingpt import sample_with_past
  11. rescale = lambda x: (x + 1.) / 2.
  12. def chw_to_pillow(x):
  13. return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
  14. @torch.no_grad()
  15. def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
  16. dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
  17. log = dict()
  18. assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
  19. qzshape = [batch_size, dim_z, h, w]
  20. assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
  21. c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
  22. t1 = time.time()
  23. index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
  24. sample_logits=True, top_k=top_k, callback=callback,
  25. temperature=temperature, top_p=top_p)
  26. if verbose_time:
  27. sampling_time = time.time() - t1
  28. print(f"Full sampling takes about {sampling_time:.2f} seconds.")
  29. x_sample = model.decode_to_img(index_sample, qzshape)
  30. log["samples"] = x_sample
  31. log["class_label"] = c_indices
  32. return log
  33. @torch.no_grad()
  34. def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
  35. dim_z=256, h=16, w=16, verbose_time=False):
  36. log = dict()
  37. qzshape = [batch_size, dim_z, h, w]
  38. assert model.be_unconditional, 'Expecting an unconditional model.'
  39. c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
  40. t1 = time.time()
  41. index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
  42. sample_logits=True, top_k=top_k, callback=callback,
  43. temperature=temperature, top_p=top_p)
  44. if verbose_time:
  45. sampling_time = time.time() - t1
  46. print(f"Full sampling takes about {sampling_time:.2f} seconds.")
  47. x_sample = model.decode_to_img(index_sample, qzshape)
  48. log["samples"] = x_sample
  49. return log
  50. @torch.no_grad()
  51. def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
  52. given_classes=None, top_p=None):
  53. batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
  54. if not unconditional:
  55. assert given_classes is not None
  56. print("Running in pure class-conditional sampling mode. I will produce "
  57. f"{num_samples} samples for each of the {len(given_classes)} classes, "
  58. f"i.e. {num_samples*len(given_classes)} in total.")
  59. for class_label in tqdm(given_classes, desc="Classes"):
  60. for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
  61. if bs == 0: break
  62. logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
  63. temperature=temperature, top_k=top_k, top_p=top_p)
  64. save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
  65. else:
  66. print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
  67. for n, bs in tqdm(enumerate(batches), desc="Sampling"):
  68. if bs == 0: break
  69. logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
  70. save_from_logs(logs, logdir, base_count=n * batch_size)
  71. def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
  72. xx = logs[key]
  73. for i, x in enumerate(xx):
  74. x = chw_to_pillow(x)
  75. count = base_count + i
  76. if cond_key is None:
  77. x.save(os.path.join(logdir, f"{count:06}.png"))
  78. else:
  79. condlabel = cond_key[i]
  80. if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
  81. os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
  82. x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
  83. def get_parser():
  84. def str2bool(v):
  85. if isinstance(v, bool):
  86. return v
  87. if v.lower() in ("yes", "true", "t", "y", "1"):
  88. return True
  89. elif v.lower() in ("no", "false", "f", "n", "0"):
  90. return False
  91. else:
  92. raise argparse.ArgumentTypeError("Boolean value expected.")
  93. parser = argparse.ArgumentParser()
  94. parser.add_argument(
  95. "-r",
  96. "--resume",
  97. type=str,
  98. nargs="?",
  99. help="load from logdir or checkpoint in logdir",
  100. )
  101. parser.add_argument(
  102. "-o",
  103. "--outdir",
  104. type=str,
  105. nargs="?",
  106. help="path where the samples will be logged to.",
  107. default=""
  108. )
  109. parser.add_argument(
  110. "-b",
  111. "--base",
  112. nargs="*",
  113. metavar="base_config.yaml",
  114. help="paths to base configs. Loaded from left-to-right. "
  115. "Parameters can be overwritten or added with command-line options of the form `--key value`.",
  116. default=list(),
  117. )
  118. parser.add_argument(
  119. "-n",
  120. "--num_samples",
  121. type=int,
  122. nargs="?",
  123. help="num_samples to draw",
  124. default=50000
  125. )
  126. parser.add_argument(
  127. "--batch_size",
  128. type=int,
  129. nargs="?",
  130. help="the batch size",
  131. default=25
  132. )
  133. parser.add_argument(
  134. "-k",
  135. "--top_k",
  136. type=int,
  137. nargs="?",
  138. help="top-k value to sample with",
  139. default=250,
  140. )
  141. parser.add_argument(
  142. "-t",
  143. "--temperature",
  144. type=float,
  145. nargs="?",
  146. help="temperature value to sample with",
  147. default=1.0
  148. )
  149. parser.add_argument(
  150. "-p",
  151. "--top_p",
  152. type=float,
  153. nargs="?",
  154. help="top-p value to sample with",
  155. default=1.0
  156. )
  157. parser.add_argument(
  158. "--classes",
  159. type=str,
  160. nargs="?",
  161. help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
  162. default="imagenet"
  163. )
  164. return parser
  165. def load_model_from_config(config, sd, gpu=True, eval_mode=True):
  166. model = instantiate_from_config(config)
  167. if sd is not None:
  168. model.load_state_dict(sd)
  169. if gpu:
  170. model.cuda()
  171. if eval_mode:
  172. model.eval()
  173. return {"model": model}
  174. def load_model(config, ckpt, gpu, eval_mode):
  175. # load the specified checkpoint
  176. if ckpt:
  177. pl_sd = torch.load(ckpt, map_location="cpu")
  178. global_step = pl_sd["global_step"]
  179. print(f"loaded model from global step {global_step}.")
  180. else:
  181. pl_sd = {"state_dict": None}
  182. global_step = None
  183. model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
  184. return model, global_step
  185. if __name__ == "__main__":
  186. sys.path.append(os.getcwd())
  187. parser = get_parser()
  188. opt, unknown = parser.parse_known_args()
  189. assert opt.resume
  190. ckpt = None
  191. if not os.path.exists(opt.resume):
  192. raise ValueError("Cannot find {}".format(opt.resume))
  193. if os.path.isfile(opt.resume):
  194. paths = opt.resume.split("/")
  195. try:
  196. idx = len(paths)-paths[::-1].index("logs")+1
  197. except ValueError:
  198. idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
  199. logdir = "/".join(paths[:idx])
  200. ckpt = opt.resume
  201. else:
  202. assert os.path.isdir(opt.resume), opt.resume
  203. logdir = opt.resume.rstrip("/")
  204. ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
  205. base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
  206. opt.base = base_configs+opt.base
  207. configs = [OmegaConf.load(cfg) for cfg in opt.base]
  208. cli = OmegaConf.from_dotlist(unknown)
  209. config = OmegaConf.merge(*configs, cli)
  210. model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
  211. if opt.outdir:
  212. print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
  213. logdir = opt.outdir
  214. if opt.classes == "imagenet":
  215. given_classes = [i for i in range(1000)]
  216. else:
  217. cls_str = opt.classes
  218. assert not cls_str.endswith(","), 'class string should not end with a ","'
  219. given_classes = [int(c) for c in cls_str.split(",")]
  220. logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
  221. f"{global_step}")
  222. print(f"Logging to {logdir}")
  223. os.makedirs(logdir, exist_ok=True)
  224. run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
  225. given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
  226. print("done.")