# --------------------------------------------------------
# training code for CUT3R
# --------------------------------------------------------
# References:
# DUSt3R: https://github.com/naver/dust3r
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import sys
import time
import math
from collections import defaultdict
from pathlib import Path
from typing import Sized
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
from dust3r.model import (
PreTrainedModel,
ARCroco3DStereo,
ARCroco3DStereoConfig,
inf,
strip_module,
) # noqa: F401, needed when loading the model
from dust3r.datasets import get_data_loader
from dust3r.losses import * # noqa: F401, needed when loading the model
from dust3r.inference import loss_of_one_batch, loss_of_one_batch_tbptt # noqa
from dust3r.viz import colorize
from dust3r.utils.render import get_render_results
import dust3r.utils.path_to_croco # noqa: F401
import croco.utils.misc as misc # noqa
from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa
import hydra
from omegaconf import OmegaConf
import logging
import pathlib
from tqdm import tqdm
import random
import builtins
import shutil
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs
from accelerate.logging import get_logger
from datetime import timedelta
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy("file_system")
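# Use the file_system sharing strategy so that many DataLoader workers do not
# exhaust the per-process file-descriptor limit when exchanging tensors.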
printer = get_logger(__name__, log_level="DEBUG")
def setup_for_distributed(accelerator: Accelerator):
"""
This function disables printing when not in master process
"""
builtin_print = builtins.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
force = force or (accelerator.num_processes > 8)
if accelerator.is_main_process or force:
now = datetime.datetime.now().time()
builtin_print("[{}] ".format(now), end="") # print with time stamp
builtin_print(*args, **kwargs)
builtins.print = print
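# Snapshot the current source tree into <output_dir>/code/<timestamp> so each
# run keeps a copy of the exact code it was launched with; large or generated
# directories and binary assets are skipped via ignore_patterns below.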
def save_current_code(outdir):
now = datetime.datetime.now() # current date and time
date_time = now.strftime("%m_%d-%H:%M:%S")
src_dir = "."
dst_dir = os.path.join(outdir, "code", "{}".format(date_time))
shutil.copytree(
src_dir,
dst_dir,
ignore=shutil.ignore_patterns(
".vscode*",
"assets*",
"example*",
"checkpoints*",
"OLD*",
"logs*",
"out*",
"runs*",
"*.png",
"*.mp4",
"*__pycache__*",
"*.git*",
"*.idea*",
"*.zip",
"*.jpg",
),
dirs_exist_ok=True,
)
return dst_dir
def train(args):
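    # Accelerate handles device placement and multi-GPU orchestration: gradients
    # are accumulated over args.accum_iter steps, computation runs in bf16 mixed
    # precision, DDP is told to tolerate unused parameters (presumably because
    # not every head receives gradients on every step), and the process-group
    # timeout is raised to 100 minutes to survive slow dataset/checkpoint setup.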
accelerator = Accelerator(
gradient_accumulation_steps=args.accum_iter,
mixed_precision="bf16",
kwargs_handlers=[
DistributedDataParallelKwargs(find_unused_parameters=True),
InitProcessGroupKwargs(timeout=timedelta(seconds=6000)),
],
)
device = accelerator.device
setup_for_distributed(accelerator)
printer.info("output_dir: " + args.output_dir)
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
if accelerator.is_main_process:
dst_dir = save_current_code(outdir=args.output_dir)
printer.info(f"Saving current code to {dst_dir}")
    # auto resume: if no checkpoint was given explicitly, fall back to the most
    # recent "checkpoint-last.pth" in the output directory (if it exists)
    if not args.resume:
        last_ckpt_fname = os.path.join(args.output_dir, "checkpoint-last.pth")
args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None
printer.info("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    # fix the seed (offset by the process index so every rank gets its own RNG stream)
seed = args.seed + accelerator.state.process_index
printer.info(
f"Setting seed to {seed} for process {accelerator.state.process_index}"
)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = args.benchmark
# training dataset and loader
printer.info("Building train dataset %s", args.train_dataset)
data_loader_train = build_dataset(
args.train_dataset,
args.batch_size,
args.num_workers,
accelerator=accelerator,
test=False,
fixed_length=args.fixed_length
)
printer.info("Building test dataset %s", args.test_dataset)
data_loader_test = {
dataset.split("(")[0]: build_dataset(
dataset,
args.batch_size,
args.num_workers,
accelerator=accelerator,
test=True,
fixed_length=True
)
for dataset in args.test_dataset.split("+")
}
# model
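    # args.model (and the criterion strings below) are Python expressions taken
    # from the Hydra config and instantiated with eval(), e.g. an
    # ARCroco3DStereo(...) constructor call.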
printer.info("Loading model: %s", args.model)
model: PreTrainedModel = eval(args.model)
printer.info(f"All model parameters: {sum(p.numel() for p in model.parameters())}")
printer.info(
f"Encoder parameters: {sum(p.numel() for p in model.enc_blocks.parameters())}"
)
printer.info(
f"Decoder parameters: {sum(p.numel() for p in model.dec_blocks.parameters())}"
)
printer.info(f">> Creating train criterion = {args.train_criterion}")
train_criterion = eval(args.train_criterion).to(device)
printer.info(
f">> Creating test criterion = {args.test_criterion or args.train_criterion}"
)
    test_criterion = eval(args.test_criterion or args.train_criterion).to(device)
model.to(device)
if args.gradient_checkpointing:
model.gradient_checkpointing_enable()
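    # Long-context mode lifts the fixed input-length constraint so sequences can
    # be processed in chunks with truncated BPTT (see loss_of_one_batch_tbptt in
    # the training loop).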
if args.long_context:
model.fixed_input_length = False
if args.pretrained and not args.resume:
printer.info(f"Loading pretrained: {args.pretrained}")
ckpt = torch.load(args.pretrained, map_location=device)
load_only_encoder = getattr(args, "load_only_encoder", False)
if load_only_encoder:
filtered_state_dict = {
k: v
for k, v in ckpt["model"].items()
if "enc_blocks" in k or "patch_embed" in k
}
printer.info(
model.load_state_dict(strip_module(filtered_state_dict), strict=False)
)
else:
printer.info(
model.load_state_dict(strip_module(ckpt["model"]), strict=False)
)
del ckpt # in case it occupies memory
    # following timm: set weight decay to 0 for bias and norm layers
param_groups = misc.get_parameter_groups(model, args.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
# print(optimizer)
loss_scaler = NativeScaler(accelerator=accelerator)
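    # Disable Accelerate's even_batches behaviour (which would pad/duplicate
    # samples so every process sees the same number of batches); the custom
    # batch sampler is expected to handle uneven sharding itself. prepare()
    # wraps the model for DDP and moves optimizer/dataloader state to the
    # right devices.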
accelerator.even_batches = False
optimizer, model, data_loader_train = accelerator.prepare(
optimizer, model, data_loader_train
)
def write_log_stats(epoch, train_stats, test_stats):
if accelerator.is_main_process:
if log_writer is not None:
log_writer.flush()
log_stats = dict(
epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()}
)
for test_name in data_loader_test:
if test_name not in test_stats:
continue
log_stats.update(
{test_name + "_" + k: v for k, v in test_stats[test_name].items()}
)
with open(
os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
) as f:
f.write(json.dumps(log_stats) + "\n")
def save_model(epoch, fname, best_so_far):
misc.save_model(
accelerator=accelerator,
args=args,
model_without_ddp=model,
optimizer=optimizer,
loss_scaler=loss_scaler,
epoch=epoch,
fname=fname,
best_so_far=best_so_far,
)
best_so_far = misc.load_model(
args=args, model_without_ddp=model, optimizer=optimizer, loss_scaler=loss_scaler
)
if best_so_far is None:
best_so_far = float("inf")
log_writer = (
SummaryWriter(log_dir=args.output_dir) if accelerator.is_main_process else None
)
printer.info(f"Start training for {args.epochs} epochs")
start_time = time.time()
train_stats = test_stats = {}
for epoch in range(args.start_epoch, args.epochs + 1):
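        # The range extends to args.epochs inclusive: the extra iteration only
        # evaluates and writes the final checkpoints, then breaks before
        # training again.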
        # Immediately save the last checkpoint from the previous epoch
if epoch > args.start_epoch:
if (
args.save_freq
and np.allclose(epoch / args.save_freq, int(epoch / args.save_freq))
or epoch == args.epochs
):
save_model(epoch - 1, "last", best_so_far)
# Test on multiple datasets
new_best = False
if epoch > 0 and args.eval_freq > 0 and epoch % args.eval_freq == 0:
test_stats = {}
for test_name, testset in data_loader_test.items():
stats = test_one_epoch(
model,
test_criterion,
testset,
accelerator,
device,
epoch,
log_writer=log_writer,
args=args,
prefix=test_name,
)
test_stats[test_name] = stats
# Save best of all
if stats["loss_med"] < best_so_far:
best_so_far = stats["loss_med"]
new_best = True
# Save more stuff
write_log_stats(epoch, train_stats, test_stats)
if epoch > args.start_epoch:
if args.keep_freq and epoch % args.keep_freq == 0:
save_model(epoch - 1, str(epoch), best_so_far)
if new_best:
save_model(epoch - 1, "best", best_so_far)
if epoch >= args.epochs:
break # exit after writing last test to disk
# Train
train_stats = train_one_epoch(
model,
train_criterion,
data_loader_train,
optimizer,
accelerator,
epoch,
loss_scaler,
log_writer=log_writer,
args=args,
)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
printer.info("Training time {}".format(total_time_str))
save_final_model(accelerator, args, args.epochs, model, best_so_far=best_so_far)
def save_final_model(accelerator, args, epoch, model_without_ddp, best_so_far=None):
output_dir = Path(args.output_dir)
checkpoint_path = output_dir / "checkpoint-final.pth"
to_save = {
"args": args,
"model": (
model_without_ddp
if isinstance(model_without_ddp, dict)
else model_without_ddp.cpu().state_dict()
),
"epoch": epoch,
}
if best_so_far is not None:
to_save["best_so_far"] = best_so_far
printer.info(f">> Saving model to {checkpoint_path} ...")
misc.save_on_master(accelerator, to_save, checkpoint_path)
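# Thin wrapper around dust3r's get_data_loader: shuffling and drop_last are
# enabled for training only; fixed_length is forwarded to the loader (the test
# loaders above always use fixed-length sequences).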
def build_dataset(dataset, batch_size, num_workers, accelerator, test=False, fixed_length=False):
split = ["Train", "Test"][test]
printer.info(f"Building {split} Data loader for dataset: {dataset}")
loader = get_data_loader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_mem=True,
shuffle=not (test),
drop_last=not (test),
accelerator=accelerator,
fixed_length=fixed_length
)
return loader
def train_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
data_loader: Sized,
optimizer: torch.optim.Optimizer,
accelerator: Accelerator,
epoch: int,
loss_scaler,
args,
log_writer=None,
):
assert torch.backends.cuda.matmul.allow_tf32 == True
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = "Epoch: [{}]".format(epoch)
accum_iter = args.accum_iter
def save_model(epoch, fname, best_so_far):
misc.save_model(
accelerator=accelerator,
args=args,
model_without_ddp=model,
optimizer=optimizer,
loss_scaler=loss_scaler,
epoch=epoch,
fname=fname,
best_so_far=best_so_far,
)
if log_writer is not None:
printer.info("log_dir: {}".format(log_writer.log_dir))
if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
data_loader.dataset.set_epoch(epoch)
if (
hasattr(data_loader, "batch_sampler")
and hasattr(data_loader.batch_sampler, "batch_sampler")
and hasattr(data_loader.batch_sampler.batch_sampler, "set_epoch")
):
data_loader.batch_sampler.batch_sampler.set_epoch(epoch)
optimizer.zero_grad()
for data_iter_step, batch in enumerate(
metric_logger.log_every(data_loader, args.print_freq, accelerator, header)
):
with accelerator.accumulate(model):
epoch_f = epoch + data_iter_step / len(data_loader)
step = int(epoch_f * len(data_loader))
# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0:
misc.adjust_learning_rate(optimizer, epoch_f, args)
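            # Two training paths: the default computes the loss over the whole
            # view sequence at once, while long-context mode uses truncated
            # backpropagation through time in chunks of 4 views; in that case
            # the helper performs the backward/optimizer steps internally and
            # flags it via "already_backprop" below.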
if not args.long_context:
result = loss_of_one_batch(
batch,
model,
criterion,
accelerator,
symmetrize_batch=False,
use_amp=bool(args.amp),
)
else:
result = loss_of_one_batch_tbptt(
batch,
model,
criterion,
chunk_size=4,
loss_scaler=loss_scaler,
optimizer=optimizer,
accelerator=accelerator,
symmetrize_batch=False,
use_amp=bool(args.amp),
)
loss, loss_details = result["loss"] # criterion returns two values
loss_value = float(loss)
if not math.isfinite(loss_value):
print(
f"Loss is {loss_value}, stopping training, loss details: {loss_details}"
)
sys.exit(1)
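            # Standard path: the scaler helper runs the backward pass, clips the
            # gradient norm to 1.0 and steps the optimizer (skipped when the
            # TBPTT path has already back-propagated).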
if not result.get("already_backprop", False):
loss_scaler(
loss,
optimizer,
parameters=model.parameters(),
update_grad=True,
clip_grad=1.0,
)
optimizer.zero_grad()
is_metric = batch[0]["is_metric"]
curr_num_view = len(batch)
del loss
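            # Only keep the batch around when it is time to dump qualitative
            # visualisations to TensorBoard; otherwise free it immediately.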
tb_vis_img = (data_iter_step + 1) % accum_iter == 0 and (
(step + 1) % (args.print_img_freq)
) == 0
if not tb_vis_img:
del batch
else:
torch.cuda.empty_cache()
lr = optimizer.param_groups[0]["lr"]
metric_logger.update(epoch=epoch_f)
metric_logger.update(lr=lr)
metric_logger.update(step=step)
metric_logger.update(loss=loss_value, **loss_details)
if (data_iter_step + 1) % accum_iter == 0 and (
(data_iter_step + 1) % (accum_iter * args.print_freq)
) == 0:
loss_value_reduce = accelerator.gather(
torch.tensor(loss_value).to(accelerator.device)
).mean() # MUST BE EXECUTED BY ALL NODES
if log_writer is None:
continue
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int(epoch_f * 1000)
log_writer.add_scalar("train_loss", loss_value_reduce, step)
log_writer.add_scalar("train_lr", lr, step)
log_writer.add_scalar("train_iter", epoch_1000x, step)
for name, val in loss_details.items():
if isinstance(val, torch.Tensor):
if val.ndim > 0:
continue
if isinstance(val, dict):
continue
log_writer.add_scalar("train_" + name, val, step)
if tb_vis_img:
if log_writer is None:
continue
with torch.no_grad():
depths_self, gt_depths_self = get_render_results(
batch, result["pred"], self_view=True
)
depths_cross, gt_depths_cross = get_render_results(
batch, result["pred"], self_view=False
)
for k in range(len(batch)):
loss_details[f"self_pred_depth_{k+1}"] = (
depths_self[k].detach().cpu()
)
loss_details[f"self_gt_depth_{k+1}"] = (
gt_depths_self[k].detach().cpu()
)
loss_details[f"pred_depth_{k+1}"] = (
depths_cross[k].detach().cpu()
)
loss_details[f"gt_depth_{k+1}"] = (
gt_depths_cross[k].detach().cpu()
)
imgs_stacked_dict = get_vis_imgs_new(
loss_details, args.num_imgs_vis, curr_num_view, is_metric=is_metric
)
for name, imgs_stacked in imgs_stacked_dict.items():
log_writer.add_images(
"train" + "/" + name, imgs_stacked, step, dataformats="HWC"
)
del batch
if (
data_iter_step % int(args.save_freq * len(data_loader)) == 0
and data_iter_step != 0
and data_iter_step != len(data_loader) - 1
):
print("saving at step", data_iter_step)
save_model(epoch - 1, "last", float("inf"))
# gather the stats from all processes
metric_logger.synchronize_between_processes(accelerator)
printer.info("Averaged stats: %s", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def test_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
data_loader: Sized,
accelerator: Accelerator,
device: torch.device,
epoch: int,
args,
log_writer=None,
prefix="test",
):
model.eval()
metric_logger = misc.MetricLogger(delimiter=" ")
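    # Use an effectively unbounded smoothing window so that the median/average
    # statistics reported below cover the entire test epoch.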
metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9))
header = "Test Epoch: [{}]".format(epoch)
if log_writer is not None:
printer.info("log_dir: {}".format(log_writer.log_dir))
if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
data_loader.dataset.set_epoch(0)
if (
hasattr(data_loader, "batch_sampler")
and hasattr(data_loader.batch_sampler, "batch_sampler")
and hasattr(data_loader.batch_sampler.batch_sampler, "set_epoch")
):
data_loader.batch_sampler.batch_sampler.set_epoch(0)
for _, batch in enumerate(
metric_logger.log_every(data_loader, args.print_freq, accelerator, header)
):
result = loss_of_one_batch(
batch,
model,
criterion,
accelerator,
symmetrize_batch=False,
use_amp=bool(args.amp),
)
loss_value, loss_details = result["loss"] # criterion returns two values
metric_logger.update(loss=float(loss_value), **loss_details)
printer.info("Averaged stats: %s", metric_logger)
aggs = [("avg", "global_avg"), ("med", "median")]
results = {
f"{k}_{tag}": getattr(meter, attr)
for k, meter in metric_logger.meters.items()
for tag, attr in aggs
}
if log_writer is not None:
for name, val in results.items():
if isinstance(val, torch.Tensor):
if val.ndim > 0:
continue
if isinstance(val, dict):
continue
log_writer.add_scalar(prefix + "_" + name, val, 1000 * epoch)
depths_self, gt_depths_self = get_render_results(
batch, result["pred"], self_view=True
)
depths_cross, gt_depths_cross = get_render_results(
batch, result["pred"], self_view=False
)
for k in range(len(batch)):
loss_details[f"self_pred_depth_{k+1}"] = depths_self[k].detach().cpu()
loss_details[f"self_gt_depth_{k+1}"] = gt_depths_self[k].detach().cpu()
loss_details[f"pred_depth_{k+1}"] = depths_cross[k].detach().cpu()
loss_details[f"gt_depth_{k+1}"] = gt_depths_cross[k].detach().cpu()
imgs_stacked_dict = get_vis_imgs_new(
loss_details,
args.num_imgs_vis,
args.num_test_views,
is_metric=batch[0]["is_metric"],
)
for name, imgs_stacked in imgs_stacked_dict.items():
log_writer.add_images(
prefix + "/" + name, imgs_stacked, 1000 * epoch, dataformats="HWC"
)
del loss_details, loss_value, batch
torch.cuda.empty_cache()
return results
def batch_append(original_list, new_list):
for sublist, new_item in zip(original_list, new_list):
sublist.append(new_item)
return original_list
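# Build an H x (num_views * W) indicator strip per example: 0 where a view
# provides only an image, 1 where it provides only rays, 0.5 otherwise.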
def gen_mask_indicator(img_mask_list, ray_mask_list, num_views, h, w):
output = []
for img_mask, ray_mask in zip(img_mask_list, ray_mask_list):
out = torch.zeros((h, w * num_views, 3))
for i in range(num_views):
if img_mask[i] and not ray_mask[i]:
offset = 0
elif not img_mask[i] and ray_mask[i]:
offset = 1
else:
offset = 0.5
out[:, i * w : (i + 1) * w] += offset
output.append(out)
return output
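# Colorize depth maps (clipped to the 1st-99th percentile range, shared between
# GT and prediction when depths are metric) and stack all visualisations for
# one example into a single image: mask indicator, RGB, self-view depth/conf,
# cross-view depth/conf.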
def vis_and_cat(
gt_imgs,
pred_imgs,
cross_gt_depths,
cross_pred_depths,
self_gt_depths,
self_pred_depths,
cross_conf,
self_conf,
ray_indicator,
is_metric,
):
cross_depth_gt_min = torch.quantile(cross_gt_depths, 0.01).item()
cross_depth_gt_max = torch.quantile(cross_gt_depths, 0.99).item()
cross_depth_pred_min = torch.quantile(cross_pred_depths, 0.01).item()
cross_depth_pred_max = torch.quantile(cross_pred_depths, 0.99).item()
cross_depth_min = min(cross_depth_gt_min, cross_depth_pred_min)
cross_depth_max = max(cross_depth_gt_max, cross_depth_pred_max)
cross_gt_depths_vis = colorize(
cross_gt_depths,
range=(
(cross_depth_min, cross_depth_max)
if is_metric
else (cross_depth_gt_min, cross_depth_gt_max)
),
append_cbar=True,
)
cross_pred_depths_vis = colorize(
cross_pred_depths,
range=(
(cross_depth_min, cross_depth_max)
if is_metric
else (cross_depth_pred_min, cross_depth_pred_max)
),
append_cbar=True,
)
self_depth_gt_min = torch.quantile(self_gt_depths, 0.01).item()
self_depth_gt_max = torch.quantile(self_gt_depths, 0.99).item()
self_depth_pred_min = torch.quantile(self_pred_depths, 0.01).item()
self_depth_pred_max = torch.quantile(self_pred_depths, 0.99).item()
self_depth_min = min(self_depth_gt_min, self_depth_pred_min)
self_depth_max = max(self_depth_gt_max, self_depth_pred_max)
self_gt_depths_vis = colorize(
self_gt_depths,
range=(
(self_depth_min, self_depth_max)
if is_metric
else (self_depth_gt_min, self_depth_gt_max)
),
append_cbar=True,
)
self_pred_depths_vis = colorize(
self_pred_depths,
range=(
(self_depth_min, self_depth_max)
if is_metric
else (self_depth_pred_min, self_depth_pred_max)
),
append_cbar=True,
)
if len(cross_conf) > 0:
cross_conf_vis = colorize(cross_conf, append_cbar=True)
if len(self_conf) > 0:
self_conf_vis = colorize(self_conf, append_cbar=True)
gt_imgs_vis = torch.zeros_like(cross_gt_depths_vis)
gt_imgs_vis[: gt_imgs.shape[0], : gt_imgs.shape[1]] = gt_imgs
pred_imgs_vis = torch.zeros_like(cross_gt_depths_vis)
pred_imgs_vis[: pred_imgs.shape[0], : pred_imgs.shape[1]] = pred_imgs
ray_indicator_vis = torch.cat(
[
ray_indicator,
torch.zeros(
ray_indicator.shape[0],
cross_pred_depths_vis.shape[1] - ray_indicator.shape[1],
3,
),
],
dim=1,
)
out = torch.cat(
[
ray_indicator_vis,
gt_imgs_vis,
pred_imgs_vis,
self_gt_depths_vis,
self_pred_depths_vis,
self_conf_vis,
cross_gt_depths_vis,
cross_pred_depths_vis,
cross_conf_vis,
],
dim=0,
)
return out
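# Assemble TensorBoard images for the first num_imgs_vis examples: views are
# subsampled with a stride that grows with the sequence length, tiled
# horizontally per quantity, and finally concatenated by vis_and_cat.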
def get_vis_imgs_new(loss_details, num_imgs_vis, num_views, is_metric):
ret_dict = {}
gt_img_list = [[] for _ in range(num_imgs_vis)]
pred_img_list = [[] for _ in range(num_imgs_vis)]
cross_gt_depth_list = [[] for _ in range(num_imgs_vis)]
cross_pred_depth_list = [[] for _ in range(num_imgs_vis)]
self_gt_depth_list = [[] for _ in range(num_imgs_vis)]
self_pred_depth_list = [[] for _ in range(num_imgs_vis)]
cross_view_conf_list = [[] for _ in range(num_imgs_vis)]
self_view_conf_list = [[] for _ in range(num_imgs_vis)]
cross_view_conf_exits = False
self_view_conf_exits = False
img_mask_list = [[] for _ in range(num_imgs_vis)]
ray_mask_list = [[] for _ in range(num_imgs_vis)]
if num_views > 30:
stride = 5
elif num_views > 20:
stride = 3
elif num_views > 10:
stride = 2
else:
stride = 1
for i in range(0, num_views, stride):
gt_imgs = 0.5 * (loss_details[f"gt_img{i+1}"] + 1)[:num_imgs_vis].detach().cpu()
width = gt_imgs.shape[2]
pred_imgs = (
0.5 * (loss_details[f"pred_rgb_{i+1}"] + 1)[:num_imgs_vis].detach().cpu()
)
gt_img_list = batch_append(gt_img_list, gt_imgs.unbind(dim=0))
pred_img_list = batch_append(pred_img_list, pred_imgs.unbind(dim=0))
cross_pred_depths = (
loss_details[f"pred_depth_{i+1}"][:num_imgs_vis].detach().cpu()
)
cross_gt_depths = (
loss_details[f"gt_depth_{i+1}"]
.to(gt_imgs.device)[:num_imgs_vis]
.detach()
.cpu()
)
cross_pred_depth_list = batch_append(
cross_pred_depth_list, cross_pred_depths.unbind(dim=0)
)
cross_gt_depth_list = batch_append(
cross_gt_depth_list, cross_gt_depths.unbind(dim=0)
)
self_gt_depths = (
loss_details[f"self_gt_depth_{i+1}"][:num_imgs_vis].detach().cpu()
)
self_pred_depths = (
loss_details[f"self_pred_depth_{i+1}"][:num_imgs_vis].detach().cpu()
)
self_gt_depth_list = batch_append(
self_gt_depth_list, self_gt_depths.unbind(dim=0)
)
self_pred_depth_list = batch_append(
self_pred_depth_list, self_pred_depths.unbind(dim=0)
)
if f"conf_{i+1}" in loss_details:
cross_view_conf = loss_details[f"conf_{i+1}"][:num_imgs_vis].detach().cpu()
cross_view_conf_list = batch_append(
cross_view_conf_list, cross_view_conf.unbind(dim=0)
)
cross_view_conf_exits = True
if f"self_conf_{i+1}" in loss_details:
self_view_conf = (
loss_details[f"self_conf_{i+1}"][:num_imgs_vis].detach().cpu()
)
self_view_conf_list = batch_append(
self_view_conf_list, self_view_conf.unbind(dim=0)
)
self_view_conf_exits = True
img_mask_list = batch_append(
img_mask_list,
loss_details[f"img_mask_{i+1}"][:num_imgs_vis].detach().cpu().unbind(dim=0),
)
ray_mask_list = batch_append(
ray_mask_list,
loss_details[f"ray_mask_{i+1}"][:num_imgs_vis].detach().cpu().unbind(dim=0),
)
# each element in the list is [H, num_views * W, (3)], the size of the list is num_imgs_vis
gt_img_list = [torch.cat(sublist, dim=1) for sublist in gt_img_list]
pred_img_list = [torch.cat(sublist, dim=1) for sublist in pred_img_list]
cross_pred_depth_list = [
torch.cat(sublist, dim=1) for sublist in cross_pred_depth_list
]
cross_gt_depth_list = [torch.cat(sublist, dim=1) for sublist in cross_gt_depth_list]
self_gt_depth_list = [torch.cat(sublist, dim=1) for sublist in self_gt_depth_list]
self_pred_depth_list = [
torch.cat(sublist, dim=1) for sublist in self_pred_depth_list
]
cross_view_conf_list = (
[torch.cat(sublist, dim=1) for sublist in cross_view_conf_list]
if cross_view_conf_exits
else []
)
self_view_conf_list = (
[torch.cat(sublist, dim=1) for sublist in self_view_conf_list]
if self_view_conf_exits
else []
)
    # each element in the list is [num_views,], the size of the list is num_imgs_vis
img_mask_list = [torch.stack(sublist, dim=0) for sublist in img_mask_list]
ray_mask_list = [torch.stack(sublist, dim=0) for sublist in ray_mask_list]
ray_indicator = gen_mask_indicator(
img_mask_list, ray_mask_list, len(img_mask_list[0]), 30, width
)
for i in range(num_imgs_vis):
out = vis_and_cat(
gt_img_list[i],
pred_img_list[i],
cross_gt_depth_list[i],
cross_pred_depth_list[i],
self_gt_depth_list[i],
self_pred_depth_list[i],
cross_view_conf_list[i],
self_view_conf_list[i],
ray_indicator[i],
is_metric[i],
)
ret_dict[f"imgs_{i}"] = out
return ret_dict
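# Hydra entry point: resolves ../config/train.yaml (relative to this file),
# creates the log directory and hands the config to train(). A typical launch
# might look like the following (hypothetical invocation; the exact overrides
# depend on your accelerate and Hydra configs):
#   accelerate launch src/train.py logdir=./logs/my_run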
@hydra.main(
version_base=None,
config_path=str(os.path.dirname(os.path.abspath(__file__))) + "/../config",
config_name="train.yaml",
)
def run(cfg: OmegaConf):
OmegaConf.resolve(cfg)
logdir = pathlib.Path(cfg.logdir)
logdir.mkdir(parents=True, exist_ok=True)
train(cfg)
if __name__ == "__main__":
run()