|
import os |
|
import torch |
|
import numpy as np |
|
from tqdm import trange |
|
from PIL import Image |
|
|
|
|
|
def get_state(gpu):
    """Load the MiDaS model and its matching input transform via torch.hub.

    Args:
        gpu: if truthy, the model is moved to CUDA with ``.cuda()``.

    Returns:
        dict with two entries:
        ``"model"`` -- the hub ``"MiDaS"`` model, switched to eval mode;
        ``"transform"`` -- ``default_transform`` from the hub ``"transforms"``.
    """
    hub_repo = "intel-isl/MiDaS"

    network = torch.hub.load(hub_repo, "MiDaS")
    if gpu:
        network.cuda()
    network.eval()

    hub_transforms = torch.hub.load(hub_repo, "transforms")
    return {"model": network, "transform": hub_transforms.default_transform}
|
|
|
|
|
def depth_to_rgba(x):
    """Pack a 2-D float32 map into an (H, W, 4) uint8 array, byte for byte.

    No numeric conversion takes place. The four native-order bytes of each
    float32 value become the four channels of the corresponding output pixel,
    so the exact values can be recovered with ``rgba_to_depth``. The result is
    a fresh C-contiguous array that never shares memory with ``x``.
    """
    assert x.dtype == np.float32
    assert x.ndim == 2
    # .copy() yields a C-ordered buffer, which .view() needs in order to split
    # the last axis into bytes (this also handles non-contiguous inputs).
    as_bytes = x.copy().view(np.uint8)
    return np.ascontiguousarray(as_bytes.reshape(*x.shape, 4))
|
|
|
|
|
def rgba_to_depth(x):
    """Reinterpret an (H, W, 4) uint8 array as an (H, W) float32 map.

    This undoes ``depth_to_rgba``: each pixel's four bytes are read back,
    in native byte order, as one float32 value. The result is a fresh
    C-contiguous array that never shares memory with ``x``.
    """
    assert x.dtype == np.uint8
    assert x.ndim == 3 and x.shape[2] == 4
    height, width = x.shape[:2]
    # Copy first: .view(np.float32) needs a C-contiguous byte buffer.
    as_floats = x.copy().view(np.float32)
    return np.ascontiguousarray(as_floats.reshape(height, width))
|
|
|
|
|
def run(x, state):
    """Predict a depth map for a single image using the model in ``state``.

    Args:
        x: image array. It is mapped through ``(x + 1.0) * 127.5`` before the
            transform, so it is presumably an (H, W, C) float image in
            [-1, 1] (TODO confirm against callers).
        state: dict with ``"model"`` (a torch module) and ``"transform"``
            (callable that turns the rescaled image into a batched input
            tensor), as returned by ``get_state``.

    Returns:
        numpy array holding the model prediction, bicubically resized to
        ``x.shape[:2]``, with singleton dimensions squeezed away.
    """
    model = state["model"]
    transform = state["transform"]
    hw = x.shape[:2]

    # Send the input to whatever device the model lives on instead of
    # hard-coding CUDA; otherwise a model loaded with get_state(gpu=False)
    # crashes with a device mismatch (or on machines without CUDA).
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        batch = transform((x + 1.0) * 127.5).to(device)
        prediction = model(batch)
        # Add a channel axis so interpolate sees (N, C, H, W), then resize the
        # prediction back to the input image resolution.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=hw,
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    return prediction.cpu().numpy()
|
|
|
|
|
def get_filename(relpath, level=-2):
    """Return ``(parent_folder, stem)`` for a path split on ``os.sep``.

    ``stem`` is the last path component cut at its FIRST ``.``; for example,
    ``"a.b.JPEG"`` yields ``"a"``. Raises IndexError when fewer than two
    components are available after slicing with ``level``.
    """
    components = relpath.split(os.sep)[level:]
    parent, basename = components[-2], components[-1]
    stem, _sep, _rest = basename.partition(".")
    return parent, stem
|
|
|
|
|
def save_depth(dataset, path, debug=False):
    """Run MiDaS over ``dataset`` and save one depth PNG per example.

    Each example is indexed as ``ex["image"]`` and ``ex["relpath"]``. Its
    prediction is written to ``<path>/<folder>/<stem>.png``, where
    ``folder``/``stem`` come from ``get_filename(relpath)``. The float32
    prediction is stored byte-for-byte in the four channels via
    ``depth_to_rgba``.

    Args:
        dataset: sized, indexable collection of examples.
        path: output directory. It must not exist yet, because
            ``os.makedirs`` raises FileExistsError otherwise.
        debug: if true, only the first 10 examples are processed (fewer if
            the dataset is smaller).
    """
    os.makedirs(path)

    # Size the loop from the argument, not from a module-level global.
    count = len(dataset)
    if debug:
        # Guard small datasets so debug mode cannot index past the end.
        count = min(count, 10)

    state = get_state(gpu=True)
    for idx in trange(count, desc="Data"):
        ex = dataset[idx]
        image, relpath = ex["image"], ex["relpath"]
        folder, filename = get_filename(relpath)

        folderabspath = os.path.join(path, folder)
        os.makedirs(folderabspath, exist_ok=True)
        savepath = os.path.join(folderabspath, filename)

        depth = run(image, state)
        Image.fromarray(depth_to_rgba(depth)).save("{}.png".format(savepath))
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: extract depth for the ImageNet validation split,
    # then the train split, into data/imagenet_depth/{val,train}. A split
    # whose output folder already exists is skipped.
    import sys

    from taming.data.imagenet import ImageNetTrain, ImageNetValidation

    out = "data/imagenet_depth"
    if not os.path.exists(out):
        print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
              "(be prepared that the output size will be larger than ImageNet itself).")
        # sys.exit does not rely on the `site` module, which provides the
        # builtin `exit` and is absent under e.g. `python -S`.
        sys.exit(1)

    # (dataset class, output subfolder, label used in the "done" message).
    splits = [
        (ImageNetValidation, "val", "validation"),
        (ImageNetTrain, "train", "train"),
    ]
    for dataset_cls, subdir, label in splits:
        abspath = os.path.join(out, subdir)
        if os.path.exists(abspath):
            print("{} exists - not doing anything.".format(abspath))
        else:
            print("preparing {}".format(abspath))
            # Construct the dataset only when this split actually needs
            # processing. It stays bound to the module-level name `dset`
            # for code that looks it up globally.
            dset = dataset_cls()
            save_depth(dset, abspath)
            print("done with {} split".format(label))

    print("done done.")
|
|