gan / taming-transformers /scripts /extract_depth.py
danurahul's picture
Add application file
d4bc11a
import os
import torch
import numpy as np
from tqdm import trange
from PIL import Image
def get_state(gpu):
import torch
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
if gpu:
midas.cuda()
midas.eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.default_transform
state = {"model": midas,
"transform": transform}
return state
def depth_to_rgba(x):
assert x.dtype == np.float32
assert len(x.shape) == 2
y = x.copy()
y.dtype = np.uint8
y = y.reshape(x.shape+(4,))
return np.ascontiguousarray(y)
def rgba_to_depth(x):
assert x.dtype == np.uint8
assert len(x.shape) == 3 and x.shape[2] == 4
y = x.copy()
y.dtype = np.float32
y = y.reshape(x.shape[:2])
return np.ascontiguousarray(y)
def run(x, state):
model = state["model"]
transform = state["transform"]
hw = x.shape[:2]
with torch.no_grad():
prediction = model(transform((x + 1.0) * 127.5).cuda())
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=hw,
mode="bicubic",
align_corners=False,
).squeeze()
output = prediction.cpu().numpy()
return output
def get_filename(relpath, level=-2):
# save class folder structure and filename:
fn = relpath.split(os.sep)[level:]
folder = fn[-2]
file = fn[-1].split('.')[0]
return folder, file
def save_depth(dataset, path, debug=False):
os.makedirs(path)
N = len(dset)
if debug:
N = 10
state = get_state(gpu=True)
for idx in trange(N, desc="Data"):
ex = dataset[idx]
image, relpath = ex["image"], ex["relpath"]
folder, filename = get_filename(relpath)
# prepare
folderabspath = os.path.join(path, folder)
os.makedirs(folderabspath, exist_ok=True)
savepath = os.path.join(folderabspath, filename)
# run model
xout = run(image, state)
I = depth_to_rgba(xout)
Image.fromarray(I).save("{}.png".format(savepath))
if __name__ == "__main__":
from taming.data.imagenet import ImageNetTrain, ImageNetValidation
out = "data/imagenet_depth"
if not os.path.exists(out):
print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
"(be prepared that the output size will be larger than ImageNet itself).")
exit(1)
# go
dset = ImageNetValidation()
abspath = os.path.join(out, "val")
if os.path.exists(abspath):
print("{} exists - not doing anything.".format(abspath))
else:
print("preparing {}".format(abspath))
save_depth(dset, abspath)
print("done with validation split")
dset = ImageNetTrain()
abspath = os.path.join(out, "train")
if os.path.exists(abspath):
print("{} exists - not doing anything.".format(abspath))
else:
print("preparing {}".format(abspath))
save_depth(dset, abspath)
print("done with train split")
print("done done.")