import os
import torch
import numpy as np
from tqdm import trange
from PIL import Image

def get_state(gpu):
    # load the pretrained MiDaS monocular depth model and its default input transform
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
    if gpu:
        midas.cuda()
    midas.eval()
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    transform = midas_transforms.default_transform
    state = {"model": midas,
             "transform": transform}
    return state

def depth_to_rgba(x):
    # losslessly pack a float32 depth map into a 4-channel uint8 image
    # by reinterpreting each float32 value as four bytes
    assert x.dtype == np.float32
    assert len(x.shape) == 2
    y = x.copy()
    y.dtype = np.uint8
    y = y.reshape(x.shape + (4,))
    return np.ascontiguousarray(y)

def rgba_to_depth(x):
    # inverse of depth_to_rgba: reinterpret the four uint8 channels as float32 values
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    y = x.copy()
    y.dtype = np.float32
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)
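
# Quick round-trip sketch (not used by the pipeline): the two helpers above
# reinterpret the raw bytes of a float32 depth map as four uint8 channels,
# so packing to RGBA and unpacking again should be lossless.
def _roundtrip_check():
    depth = np.random.rand(8, 8).astype(np.float32)
    assert np.array_equal(depth, rgba_to_depth(depth_to_rgba(depth)))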

def run(x, state):
    # x is expected as a float32 HWC array in [-1, 1]; rescale it to [0, 255]
    # for the MiDaS transform, predict depth, and resize back to the input resolution
    model = state["model"]
    transform = state["transform"]
    hw = x.shape[:2]
    with torch.no_grad():
        prediction = model(transform((x + 1.0) * 127.5).cuda())
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=hw,
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    output = prediction.cpu().numpy()
    return output
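
# Minimal sketch of running MiDaS on a single image file; the path and the
# [-1, 1] scaling are illustrative assumptions, chosen to match the
# (x + 1.0) * 127.5 rescaling that run() applies to dataset examples.
def _example_single_image(path="example.jpg"):
    img = np.array(Image.open(path).convert("RGB"), dtype=np.float32)
    img = img / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    state = get_state(gpu=True)
    return run(img, state)  # float32 depth map at the input resolution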

def get_filename(relpath, level=-2):
    # keep class folder structure and filename, e.g.
    # "n01440764/n01440764_10026.JPEG" -> ("n01440764", "n01440764_10026")
    fn = relpath.split(os.sep)[level:]
    folder = fn[-2]
    file = fn[-1].split('.')[0]
    return folder, file

def save_depth(dataset, path, debug=False):
    os.makedirs(path)
    N = len(dataset)
    if debug:
        N = 10
    state = get_state(gpu=True)
    for idx in trange(N, desc="Data"):
        ex = dataset[idx]
        image, relpath = ex["image"], ex["relpath"]
        folder, filename = get_filename(relpath)
        # prepare an output folder mirroring the class folder structure
        folderabspath = os.path.join(path, folder)
        os.makedirs(folderabspath, exist_ok=True)
        savepath = os.path.join(folderabspath, filename)
        # run the model and store the float32 depth map as an RGBA png
        xout = run(image, state)
        I = depth_to_rgba(xout)
        Image.fromarray(I).save("{}.png".format(savepath))
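
# Hypothetical companion to save_depth (not called by this script): read one
# of the PNGs written above back into a float32 depth map via rgba_to_depth.
def load_depth(pngpath):
    rgba = np.array(Image.open(pngpath))  # uint8 array of shape (H, W, 4)
    return rgba_to_depth(rgba)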

if __name__ == "__main__":
    from taming.data.imagenet import ImageNetTrain, ImageNetValidation
    out = "data/imagenet_depth"
    if not os.path.exists(out):
        print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
              "(be prepared that the output size will be larger than ImageNet itself).")
        exit(1)

    # validation split
    dset = ImageNetValidation()
    abspath = os.path.join(out, "val")
    if os.path.exists(abspath):
        print("{} exists - not doing anything.".format(abspath))
    else:
        print("preparing {}".format(abspath))
        save_depth(dset, abspath)
        print("done with validation split")

    # train split
    dset = ImageNetTrain()
    abspath = os.path.join(out, "train")
    if os.path.exists(abspath):
        print("{} exists - not doing anything.".format(abspath))
    else:
        print("preparing {}".format(abspath))
        save_depth(dset, abspath)
        print("done with train split")

    print("done done.")