# vista/sample.py
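"""Sampling script for the Vista driving world model.

Loads a pretrained model, conditions it on frames from nuScenes (or on a
single image from a plain folder) plus an optional action signal, and writes
the generated ("virtual") and ground-truth ("real") frames to disk.

Example invocation (assumes the script is run as a module from the repository
root so the relative import of sample_utils resolves):

    python -m vista.sample --dataset NUSCENES --action traj --n_rounds 2
"""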

from __future__ import annotations

import argparse
import json
import os
import random

import torch
from PIL import Image
from pytorch_lightning import seed_everything
from torchvision import transforms

from . import sample_utils
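
# Model versions mapped to their inference config and checkpoint paths.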
VERSION2SPECS = {
    "vwm": {"config": "configs/inference/vista.yaml", "ckpt": "ckpts/vista.safetensors"}
}

# Input sources per dataset: NUSCENES pairs a data root with an annotation
# JSON, while IMG simply reads every image in a folder.
DATASET2SOURCES = {
    "NUSCENES": {"data_root": "data/nuscenes", "anno_file": "annos/nuScenes_val.json"},
    "IMG": {"data_root": "image_folder"},
}


def parse_args(**parser_kwargs):
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("--version", type=str, default="vwm", help="model version")
    parser.add_argument("--dataset", type=str, default="NUSCENES", help="dataset name")
    parser.add_argument(
        "--save", type=str, default="outputs", help="directory to save samples"
    )
    parser.add_argument(
        "--action",
        type=str,
        default="free",
        help="action mode for control: free (no control), traj, cmd, steer or goal",
    )
    parser.add_argument(
        "--n_rounds", type=int, default=1, help="number of sampling rounds"
    )
    parser.add_argument(
        "--n_frames", type=int, default=25, help="number of frames for each round"
    )
    parser.add_argument(
        "--n_conds",
        type=int,
        default=1,
        help="number of initial condition frames for the first round",
    )
    parser.add_argument(
        "--seed", type=int, default=23, help="random seed for seed_everything"
    )
    parser.add_argument(
        "--height", type=int, default=576, help="target height of the generated video"
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="target width of the generated video"
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=2.5,
        help="scale of the classifier-free guidance",
    )
    parser.add_argument(
        "--cond_aug", type=float, default=0.0, help="strength of the noise augmentation"
    )
    parser.add_argument(
        "--n_steps", type=int, default=50, help="number of sampling steps"
    )
    parser.add_argument(
        "--rand_gen",
        action="store_false",
        help="random sample selection is enabled by default; pass this flag to iterate sequentially instead",
    )
    parser.add_argument(
        "--low_vram", action="store_true", help="enable memory-saving (low VRAM) mode"
    )
    return parser
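

# Select a sample from the chosen dataset. Returns the per-frame image paths,
# the (wrapped) sample index, the dataset size, and an optional dict of action
# conditioning tensors (None in "free" mode).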
def get_sample(
    selected_index=0, dataset_name="NUSCENES", num_frames=25, action_mode="free"
):
    dataset_dict = DATASET2SOURCES[dataset_name]
    action_dict = None
    if dataset_name == "IMG":
        image_list = os.listdir(dataset_dict["data_root"])
        total_length = len(image_list)
        # wrap the index so repeated passes cycle through the folder
        selected_index %= total_length
        image_file = image_list[selected_index]
        # a single image repeated num_frames times serves as the conditioning clip
        path_list = [os.path.join(dataset_dict["data_root"], image_file)] * num_frames
    else:
        with open(dataset_dict["anno_file"]) as anno_json:
            all_samples = json.load(anno_json)
        total_length = len(all_samples)
        selected_index %= total_length
        sample_dict = all_samples[selected_index]
        path_list = list()
        if dataset_name == "NUSCENES":
            for index in range(num_frames):
                image_path = os.path.join(
                    dataset_dict["data_root"], sample_dict["frames"][index]
                )
                assert os.path.exists(image_path), image_path
                path_list.append(image_path)
            if action_mode != "free":
                action_dict = dict()
                if action_mode == "traj" or action_mode == "trajectory":
                    action_dict["trajectory"] = torch.tensor(sample_dict["traj"][2:])
                elif action_mode == "cmd" or action_mode == "command":
                    action_dict["command"] = torch.tensor(sample_dict["cmd"])
                elif action_mode == "steer":
                    # scene might be empty
                    if sample_dict["speed"]:
                        action_dict["speed"] = torch.tensor(sample_dict["speed"][1:])
                    # scene might be empty
                    if sample_dict["angle"]:
                        # normalize the steering angle (the 780 divisor presumably
                        # matches the scaling used at training time)
                        action_dict["angle"] = (
                            torch.tensor(sample_dict["angle"][1:]) / 780
                        )
                elif action_mode == "goal":
                    # the goal point might be invalid; keep it only when it lies
                    # in front of the camera and inside the 1600x900 image plane
                    if (
                        sample_dict["z"] > 0
                        and 0 < sample_dict["goal"][0] < 1600
                        and 0 < sample_dict["goal"][1] < 900
                    ):
                        # normalize pixel coordinates to [0, 1]
                        action_dict["goal"] = torch.tensor(
                            [
                                sample_dict["goal"][0] / 1600,
                                sample_dict["goal"][1] / 900,
                            ]
                        )
                else:
                    raise ValueError(f"Unsupported action mode {action_mode}")
        else:
            raise ValueError(f"Invalid dataset {dataset_name}")
    return path_list, selected_index, total_length, action_dict
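

# Load an image, center-crop it to the target aspect ratio, resize it with
# Lanczos resampling, and rescale pixel values from [0, 1] to [-1, 1].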
def load_img(file_name, target_height=320, target_width=576, device="cuda"):
    if file_name is not None:
        image = Image.open(file_name)
        if image.mode != "RGB":
            image = image.convert("RGB")
    else:
        raise ValueError("No image file given")
    ori_w, ori_h = image.size

    # center-crop to the target aspect ratio before resizing
    if ori_w / ori_h > target_width / target_height:
        # too wide: trim the left and right borders
        tmp_w = int(target_width / target_height * ori_h)
        left = (ori_w - tmp_w) // 2
        right = (ori_w + tmp_w) // 2
        image = image.crop((left, 0, right, ori_h))
    elif ori_w / ori_h < target_width / target_height:
        # too tall: trim the top and bottom borders
        tmp_h = int(target_height / target_width * ori_w)
        top = (ori_h - tmp_h) // 2
        bottom = (ori_h + tmp_h) // 2
        image = image.crop((0, top, ori_w, bottom))
    image = image.resize((target_width, target_height), resample=Image.LANCZOS)

    # map [0, 1] tensor values to the [-1, 1] range the model expects
    image = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
    )(image)
    return image.to(device)
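

# Driver loop: fetch a sample, build the conditioning, run the sampler, and
# save the generated and ground-truth frames, repeating until one pass over
# the dataset is complete.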
if __name__ == "__main__":
    parser = parse_args()
    opt, unknown = parser.parse_known_args()

    sample_utils.set_lowvram_mode(opt.low_vram)
    version_dict = VERSION2SPECS[opt.version]
    model = sample_utils.init_model(version_dict)
    unique_keys = {x.input_key for x in model.conditioner.embedders}

    sample_index = 0
    while sample_index >= 0:
        seed_everything(opt.seed)

        frame_list, sample_index, dataset_length, action_dict = get_sample(
            sample_index, opt.dataset, opt.n_frames, opt.action
        )

        img_seq = list()
        for each_path in frame_list:
            img = load_img(each_path, opt.height, opt.width)
            img_seq.append(img)
        images = torch.stack(img_seq)

        value_dict = sample_utils.init_embedder_options(unique_keys)
        cond_img = img_seq[0][None]
        value_dict["cond_frames_without_noise"] = cond_img
        value_dict["cond_aug"] = opt.cond_aug
        # noise-augment the conditioning frame by the requested strength
        value_dict["cond_frames"] = cond_img + opt.cond_aug * torch.randn_like(cond_img)
        if action_dict is not None:
            for key, value in action_dict.items():
                value_dict[key] = value
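
        # multi-round (long-horizon) generation switches to the
        # TrianglePredictionGuider, which varies the guidance strength over the
        # clip; a single round uses plain (vanilla) classifier-free guidance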
        if opt.n_rounds > 1:
            guider = "TrianglePredictionGuider"
        else:
            guider = "VanillaCFG"
        sampler = sample_utils.init_sampling(
            guider=guider,
            steps=opt.n_steps,
            cfg_scale=opt.cfg_scale,
            num_frames=opt.n_frames,
        )

        # condition keys whose unconditional embeddings are zeroed for CFG
        uc_keys = [
            "cond_frames",
            "cond_frames_without_noise",
            "command",
            "trajectory",
            "speed",
            "angle",
            "goal",
        ]

        out = sample_utils.do_sample(
            images,
            model,
            sampler,
            value_dict,
            num_rounds=opt.n_rounds,
            num_frames=opt.n_frames,
            force_uc_zero_embeddings=uc_keys,
            initial_cond_indices=list(range(opt.n_conds)),
        )
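
        # do_sample is expected to return the decoded samples, their latents,
        # and the preprocessed input frames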
        if isinstance(out, (tuple, list)):
            samples, samples_z, inputs = out
            virtual_path = os.path.join(opt.save, "virtual")
            real_path = os.path.join(opt.save, "real")
            # save the generated and the ground-truth frames as videos, grids
            # and individual images
            for save_path, frames in ((virtual_path, samples), (real_path, inputs)):
                for save_format in ("videos", "grids", "images"):
                    sample_utils.perform_save_locally(
                        save_path, frames, save_format, opt.dataset, sample_index
                    )
        else:
            raise TypeError(f"Expected do_sample to return a tuple or list, got {type(out)}")

        if opt.rand_gen:
            # jump ahead by a random stride (at least 1, so the loop always
            # makes progress even for a single-sample dataset)
            sample_index += random.randint(1, max(dataset_length - 1, 1))
        else:
            sample_index += 1
        # stop once the index runs past the end of the dataset
        if dataset_length <= sample_index:
            sample_index = -1