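"""Sample scene-layout-conditioned images from a trained Net2NetTransformer.

For each validation layout, the script draws `--n_samples_per_layout` samples and
writes them next to a rendered plot of the conditioning layout. Resolutions larger
than the 256x256 training resolution are sampled token by token with a sliding
conditioning window.

Example invocation (a sketch; -R/-C/-N are defined below, while --resume, --outdir,
--top_k and --temperature come from scripts.make_samples.get_parser, so check that
parser for the exact spellings and defaults):

    python scripts/make_scene_samples.py --resume logs/<run>/ --outdir samples \
        -R 512,512 -C objects_bbox -N 4
"""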
import glob
import os
import sys
from itertools import product
from pathlib import Path
from typing import Literal, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision.utils import save_image
from tqdm import tqdm

from scripts.make_samples import get_parser, load_model_and_dset
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.models.cond_transformer import Net2NetTransformer

seed_everything(42424242)
device: Literal['cuda', 'cpu'] = 'cuda'
first_stage_factor = 16
trained_on_res = 256
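
# `first_stage_factor` is the spatial downsampling factor of the first-stage model
# (a VQGAN, 16 pixels per latent token) and `trained_on_res` is the image resolution
# the transformer was trained on. Resolutions beyond `trained_on_res` are sampled
# token by token with a sliding conditioning window of `first_stage_factor` latent
# positions (see `sample` below).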
def _helper(coord: int, coord_max: int, coord_window: int) -> int:
    """Left/top edge of a coord_window-sized window centered on coord, clipped so the window stays inside [0, coord_max]."""
    assert 0 <= coord < coord_max
    coord_desired_center = (coord_window - 1) // 2
    return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
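
# get_crop_coordinates returns the conditioning window around latent position (x, y)
# as a relative (x0, y0, w, h) bounding box in [0, 1], the format expected by
# conditional_builder.build. Note that it reads the global desired_z_shape set in
# the __main__ block.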
def get_crop_coordinates(x: int, y: int) -> BoundingBox:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
    y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
    w = first_stage_factor / WIDTH
    h = first_stage_factor / HEIGHT
    return x0, y0, w, h

def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(predict_x, WIDTH, first_stage_factor)
    y0 = _helper(predict_y, HEIGHT, first_stage_factor)
    no_images = z_indices.shape[0]
    # rows of the window above the current row, flattened ...
    cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
    # ... plus the tokens left of the current position in the current row
    cut_out_2 = z_indices[:, predict_y, x0:predict_x]
    return torch.cat((cut_out_1, cut_out_2), dim=1)
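
# sample() draws `no_samples` images for one layout. At the training resolution the
# whole latent grid is sampled in a single call; above it, each latent token is
# sampled individually in raster-scan order, conditioned on the layout re-encoded
# for the current crop window and on the already-sampled tokens inside that window.
# The decoded samples are returned together with a rendered plot of the layout.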
@torch.no_grad()
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
           conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
           temperature: float, top_k: int) -> Tensor:
    x_max, y_max = desired_z_shape[1], desired_z_shape[0]
    annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]

    recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
    if not recompute_conditional:
        crop_coordinates = get_crop_coordinates(0, 0)
        conditional_indices = conditional_builder.build(annotations, crop_coordinates)
        c_indices = conditional_indices.to(device).repeat(no_samples, 1)
        z_indices = torch.zeros((no_samples, 0), device=device).long()
        output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
                                      sample=True, top_k=top_k)
    else:
        output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
        for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image',
                                         total=x_max*y_max):
            crop_coordinates = get_crop_coordinates(predict_x, predict_y)
            z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
            conditional_indices = conditional_builder.build(annotations, crop_coordinates)
            c_indices = conditional_indices.to(device).repeat(no_samples, 1)
            new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
            output_indices[:, predict_y, predict_x] = new_index[:, -1]

    z_shape = (
        no_samples,
        model.first_stage_model.quantize.e_dim,  # codebook embed_dim
        desired_z_shape[0],  # z_height
        desired_z_shape[1]   # z_width
    )
    x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
    x_sample = x_sample.to('cpu')

    plotter = conditional_builder.plot
    figure_size = (x_sample.shape[2], x_sample.shape[3])
    scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
    plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
    return torch.cat((x_sample, plot.unsqueeze(0)))
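
# get_resolution parses 'height,width', clamps both sides to at least trained_on_res,
# and rounds to whole multiples of first_stage_factor. It returns the latent grid
# shape (z_h, z_w) and the matching pixel resolution.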
def get_resolution(resolution_str: str) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    if resolution_str.count(',') != 1:
        raise ValueError("Give resolution as in 'height,width'")
    res_h, res_w = resolution_str.split(',')
    res_h = max(int(res_h), trained_on_res)
    res_w = max(int(res_w), trained_on_res)
    z_h = int(round(res_h / first_stage_factor))
    z_w = int(round(res_w / first_stage_factor))
    return (z_h, z_w), (z_h * first_stage_factor, z_w * first_stage_factor)
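
# Command-line flags added on top of the shared parser from scripts.make_samples:
# target resolution, which conditional builder to use, and samples per layout.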
def add_arg_to_parser(parser):
    parser.add_argument(
        "-R",
        "--resolution",
        type=str,
        default='256,256',
        help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
    )
    parser.add_argument(
        "-C",
        "--conditional",
        type=str,
        default='objects_bbox',
        help="objects_bbox or objects_center_points",
    )
    parser.add_argument(
        "-N",
        "--n_samples_per_layout",
        type=int,
        default=4,
        help="how many samples to generate per layout",
    )
    return parser

if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = add_arg_to_parser(parser)
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir:{logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if type(opt.config) == str:
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)
    desired_z_shape, desired_resolution = get_resolution(opt.resolution)
    conditional = opt.conditional

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    data_loader = dsets.val_dataloader()
    print(dsets.datasets["validation"].conditional_builders)
    conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]

    outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
    outdir.mkdir(exist_ok=True, parents=True)
    print("Writing samples to ", outdir)

    p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
    for batch_no, batch in p_bar_1:
        save_img: Optional[Tensor] = None
        for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch',
                                   total=data_loader.batch_size):
            imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
                          opt.n_samples_per_layout, opt.temperature, opt.top_k)
            # torchvision's save_image takes `nrow` (not `n_row`): the samples and the
            # layout plot for one layout share a single grid row.
            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), nrow=opt.n_samples_per_layout + 1)