| import os | |
| import argparse | |
| from PIL import Image | |
| from glob import glob | |
| import numpy as np | |
| import json | |
| import torch | |
| import torchvision | |
| from torch.nn import functional as F | |
| from matplotlib import colormaps | |
| import math | |
| import scipy | |
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a dense grid of (x, y) coordinates for an image of size H x W.

    Args:
        height: Number of rows H.
        width: Number of columns W.
        shape: Optional leading dimensions to broadcast the grid over. Any
            sequence of ints is accepted (list, tuple, ``torch.Size``).
        dtype: ``"numpy"`` to get an ndarray back; any other value returns a
            torch tensor.
        device: Device on which the coordinates are created.
        align_corners: If True, coordinates span ``[0, 1]`` inclusive (first
            and last samples sit on the border). Otherwise they are offset
            by half a pixel, spanning ``[0.5 / W, 1 - 0.5 / W]``.
        normalize: If False, coordinates are rescaled to pixel units
            (``W - 1`` / ``H - 1`` with ``align_corners``, ``W`` / ``H``
            without).

    Returns:
        Array of shape ``(*shape, H, W, 2)`` whose last axis holds (x, y).
    """
    H, W = height, width
    # Copy into a list so tuples / torch.Size can be concatenated below.
    lead = list(shape) if shape is not None else []

    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H

    # x varies along the last (width) axis, y along the second to last.
    ones = [1] * len(lead)
    full = lead + [H, W]
    x = x.view(*(ones + [1, -1])).expand(*full)
    y = y.view(*(ones + [-1, 1])).expand(*full)
    grid = torch.stack([x, y], dim=-1)

    if dtype == "numpy":
        # .cpu() is a no-op on CPU tensors and required for any other device.
        grid = grid.cpu().numpy()
    return grid
def translation(frame, dx, dy, pad_value):
    """Shift a single frame by (dx, dy) pixels using bilinear resampling.

    Output pixel (x, y) samples the input at (x - dx, y - dy). Samples that
    fall outside the input take ``pad_value``: the frame is offset by
    ``-pad_value`` before sampling (``grid_sample`` pads with zeros by
    default) and the offset is added back afterwards.

    Args:
        frame: Tensor unpacked as (C, H, W).
        dx: Horizontal shift in pixels.
        dy: Vertical shift in pixels.
        pad_value: Value used for regions with no source pixel.

    Returns:
        The shifted frame, same shape as the input.
    """
    _, rows, cols = frame.shape

    # Normalized [0, 1] sampling coordinates, moved against the shift.
    coords = get_grid(rows, cols, device=frame.device)
    coords[..., 0] -= dx / (cols - 1)
    coords[..., 1] -= dy / (rows - 1)

    # grid_sample expects a batch axis and coordinates in [-1, 1].
    sample_at = coords[None] * 2 - 1
    centered = (frame - pad_value)[None]
    shifted = torch.nn.functional.grid_sample(
        centered, sample_at, mode='bilinear', align_corners=True
    )
    return shifted[0] + pad_value
def project(pos, t, time_steps, heigh, width):
    """Project (x, y, time) points onto the image plane with a fixed affine map.

    Pixel positions are normalized to [0, 1], centered, shrunk by a factor of
    four, paired with a normalized time coordinate ``1 - t / (T - 1)``, and
    then mapped through a constant 2x3 matrix followed by a fixed offset.
    The result is converted back to pixel units.

    Args:
        pos: Tensor of pixel coordinates whose last axis is (x, y).
        t: Time index of the points (scalar or broadcastable to ``pos[..., :1]``).
        time_steps: Total number of time steps T.
        heigh: Image height H in pixels (parameter name kept for
            backward compatibility).
        width: Image width W in pixels.

    Returns:
        Tensor with the same leading shape as ``pos`` and last axis (x, y),
        in pixel units.
    """
    T, H, W = time_steps, heigh, width

    pos = torch.stack([pos[..., 0] / (W - 1), pos[..., 1] / (H - 1)], dim=-1)
    pos = (pos - 0.5) * 0.25

    depth = 1 - torch.ones_like(pos[..., :1]) * t / (T - 1)
    pos = torch.cat([pos, depth], dim=-1)

    # Only the first two output coordinates are used, so the matrix is 2x3.
    # It is created with pos's dtype/device so float64 or GPU inputs work.
    M = torch.tensor(
        [
            [0.8, 0.0, 0.5],
            [-0.2, 1.0, 0.1],
        ],
        dtype=pos.dtype,
        device=pos.device,
    )
    pos = pos @ M.t()

    # Fixed placement of the projected volume inside the frame, then back to pixels.
    pos[..., 0] = (pos[..., 0] + 0.25) * (W - 1)
    pos[..., 1] = (pos[..., 1] + 0.45) * (H - 1)
    return pos
def draw(pos, vis, col, height, width, radius=1):
    """Splat colored points onto an H x W canvas.

    Visible points are expanded into weighted neighbor pixels (a disc when
    ``radius > 1``, otherwise the four surrounding pixels), accumulated per
    pixel, and normalized by the accumulated weight.

    Args:
        pos: Point positions, indexed as (N, 2) with columns (x, y).
        vis: Per-point visibility mask (converted with ``.bool()``).
        col: Per-point colors, indexed with the same mask as ``pos``.
        height: Canvas height H.
        width: Canvas width W.
        radius: Splat radius; values above 1 use the radius neighborhood.

    Returns:
        Tuple ``(frame, alpha)``: ``frame`` is (H, W, 3) with the weighted
        average color per pixel, ``alpha`` is (H, W, 1) and equals 1 where
        at least one point contributed and 0 elsewhere.
    """
    H, W = height, width

    visible = vis.bool()
    pts, rgb = pos[visible], col[visible]

    if radius > 1:
        pts, rgbw = get_radius_neighbors(pts, rgb, radius)
    else:
        pts, rgbw = get_cardinal_neighbors(pts, rgb)

    # Drop contributions that land outside the canvas.
    xs, ys = pts[:, 0], pts[:, 1]
    keep = (xs >= 0) & (xs <= W - 1) & (ys >= 0) & (ys <= H - 1)
    pts = pts[keep].round().long()
    rgbw = rgbw[keep]

    # Accumulate weighted color and weight per flattened pixel index.
    flat_idx = (pts[:, 1] * W + pts[:, 0]).view(-1, 1).expand(-1, 4)
    canvas = torch.zeros(H * W, 4, device=pos.device)
    canvas.scatter_add_(0, flat_idx, rgbw)
    canvas = canvas.view(H, W, 4)

    frame, weight = canvas[..., :3], canvas[..., 3]
    covered = weight > 0
    frame[covered] = frame[covered] / weight[covered][..., None]
    alpha = covered[..., None].float()
    return frame, alpha
def get_cardinal_neighbors(pos, col, eps=0.01):
    """Spread each point over its four surrounding integer pixels.

    Each corner receives a bilinear weight (plus ``eps`` on every 1-D factor,
    so no corner has exactly zero weight by default).

    Args:
        pos: (N, 2) tensor of (x, y) positions.
        col: (N, C) tensor of per-point colors.
        eps: Constant added to each 1-D interpolation weight.

    Returns:
        Tuple ``(pos, col)``: positions of shape (4N, 2) ordered as
        north-west, south-west, north-east, south-east blocks, and an
        (4N, C + 1) tensor holding weight * color followed by the weight.
    """
    x, y = pos[:, 0], pos[:, 1]
    x_lo, y_lo = x.floor(), y.floor()
    x_hi, y_hi = x_lo + 1, y_lo + 1

    # 1-D weights: the closer a point is to a side, the larger that side's weight.
    w_west = x_hi - x + eps
    w_east = x - x_lo + eps
    w_north = y_hi - y + eps
    w_south = y - y_lo + eps

    corners = (
        (x_lo, y_lo, w_north * w_west),
        (x_lo, y_hi, w_south * w_west),
        (x_hi, y_lo, w_north * w_east),
        (x_hi, y_hi, w_south * w_east),
    )

    positions = [torch.stack([cx, cy], dim=-1) for cx, cy, _ in corners]
    weighted = []
    for _, _, weight in corners:
        weight = weight[:, None]
        weighted.append(torch.cat([weight * col, weight], dim=-1))

    return torch.cat(positions, dim=0), torch.cat(weighted, dim=0)
def get_radius_neighbors(pos, col, radius):
    """Spread each point over all integer pixels within ``radius`` of it.

    Every pixel offset inside the disc gets the weight
    ``1 - distance / radius + 0.01``, so the center has the largest weight
    and the rim still contributes a small amount.

    Args:
        pos: (N, 2) tensor of (x, y) positions; each is rounded to the
            nearest pixel to obtain the disc center.
        col: (N, C) tensor of per-point colors.
        radius: Disc radius in pixels (must be non-zero).

    Returns:
        Tuple ``(pos, col)``: positions of shape (N * K, 2) and an
        (N * K, C + 1) tensor holding weight * color followed by the weight,
        where K is the number of pixel offsets inside the disc. Both live on
        ``pos.device``.
    """
    R = math.ceil(radius)
    side = 2 * R + 1
    center = torch.stack([pos[:, 0].round(), pos[:, 1].round()], dim=-1)

    # Integer offsets of the (2R+1)^2 square, built on pos's device rather
    # than unconditionally on CUDA so CPU-only runs work.
    steps = torch.arange(-R, R + 1, device=pos.device)
    offsets = torch.stack(
        [steps[None, :].expand(side, -1), steps[:, None].expand(-1, side)],
        dim=-1,
    ).view(-1, 2)

    # Keep only offsets inside the disc.
    sq_dist = offsets[:, 0] ** 2 + offsets[:, 1] ** 2
    inside = sq_dist <= radius ** 2
    offsets = offsets[inside]
    num_offsets = offsets.size(0)

    w = 1 - sq_dist[inside].sqrt() / radius + 0.01
    w = w[None].expand(pos.size(0), -1).reshape(-1)

    pos = (center.view(-1, 1, 2) + offsets.view(1, -1, 2)).view(-1, 2)

    # Repeat each color once per offset; channel count is taken from col.
    channels = col.size(-1)
    col = col.reshape(-1, 1, channels).repeat(1, num_offsets, 1).view(-1, channels)
    col = torch.cat([col * w[:, None], w[:, None]], dim=-1)
    return pos, col
def get_rainbow_colors(size):
    """Return ``size`` RGB colors sampled evenly along the "jet" colormap.

    Args:
        size: Number of colors to generate.

    Returns:
        Float tensor of shape (size, 3) with the RGB channels of each color
        (the colormap's alpha channel is dropped).
    """
    col_map = colormaps["jet"]
    # Guard the denominator: with size == 1 the original 0 / 0 produced NaN,
    # which the colormap renders as its "bad" color instead of a jet color.
    denom = max(size - 1, 1)
    col_range = np.arange(size) / denom
    col = torch.from_numpy(col_map(col_range)[..., :3]).float()
    return col.view(-1, 3)
def spline_interpolation(x, length=10):
    """Upsample a sequence along time with a cubic spline.

    Args:
        x: Tensor of shape (T, N, C); axis 0 is treated as time with samples
            at integer steps 0..T-1.
        length: Upsampling factor. With ``length == 1`` the input is
            returned unchanged.

    Returns:
        Float32 tensor of shape (T * length, N, C) on the same device as
        ``x``, sampled uniformly between the first and last time step.
    """
    if length == 1:
        return x

    # Imported here because a bare `import scipy` does not guarantee that the
    # `interpolate` submodule is loaded.
    from scipy.interpolate import CubicSpline

    T, N, C = x.shape
    knots = np.arange(T)
    # reshape (not view) tolerates non-contiguous inputs; detach drops autograd.
    samples = x.detach().reshape(T, -1).cpu().numpy()
    spline = CubicSpline(knots, samples)

    query = np.linspace(knots[0], knots[-1], T * length)
    dense = torch.from_numpy(spline(query)).view(-1, N, C).float()
    # Return on the input's device instead of unconditionally moving to CUDA.
    return dense.to(x.device)
def create_folder(path, verbose=False, exist_ok=True, safe=True):
    """Create a directory (including parents) and report whether it was created.

    Args:
        path: Directory to create.
        verbose: If True, print a message when the folder is created.
        exist_ok: If True, an already existing directory is tolerated.
        safe: If True, failures are reported by returning False; if False,
            they are raised as ``OSError`` (or a subclass).

    Returns:
        True if the directory was created by this call, False if nothing was
        created (it already existed, or creation failed in ``safe`` mode).

    Raises:
        OSError: Only when ``safe`` is False and the path exists while
            ``exist_ok`` is False, the path exists but is not a directory,
            or the directory cannot be created.
    """
    if os.path.exists(path):
        if exist_ok and os.path.isdir(path):
            # Already there and the caller allows it: nothing to do.
            return False
        if not safe:
            raise FileExistsError(f"Path already exists: {path}")
        return False

    try:
        os.makedirs(path, exist_ok=exist_ok)
    except OSError:
        # Only filesystem errors are handled; re-raise with the original
        # details when the caller asked for strict behavior.
        if not safe:
            raise
        return False

    if verbose:
        print(f"Created folder: {path}")
    return True
def write_video_to_file(video, path, channels, fps=8):
    """Encode a video tensor to an H.264 file.

    Args:
        video: Tensor of frames with values expected in [0, 1]; laid out as
            (T, C, H, W) when ``channels == "first"``, otherwise passed to
            the encoder as is (torchvision expects (T, H, W, C)).
        path: Destination file path; its parent directory is created if
            needed.
        channels: ``"first"`` if the channel axis precedes height and width.
        fps: Frame rate of the written video (defaults to the previously
            hard-coded value of 8).

    Returns:
        The uint8 (T, H, W, C) CPU tensor that was written.
    """
    create_folder(os.path.dirname(path))
    if channels == "first":
        video = video.permute(0, 2, 3, 1)
    # Round and clamp before the cast, matching write_frame: converting
    # out-of-range floats to uint8 would otherwise wrap around.
    video = (video.cpu() * 255.).round().clamp(0, 255).to(torch.uint8)
    torchvision.io.write_video(path, video, fps, "h264", options={"pix_fmt": "yuv420p", "crf": "23"})
    return video
def write_frame(frame, path, channels="first"):
    """Save a single frame tensor as an image file.

    Args:
        frame: Tensor with values scaled by 255 on write; (C, H, W) when
            ``channels == "first"``, otherwise used in its given layout.
        path: Destination image path; its parent directory is created if
            needed. The format is chosen by PIL from the file name.
        channels: ``"first"`` if the channel axis comes first.
    """
    create_folder(os.path.dirname(path))

    pixels = frame.cpu().numpy()
    if channels == "first":
        # PIL expects channels last.
        pixels = pixels.transpose(1, 2, 0)

    # Scale to 8-bit, rounding and clipping so out-of-range values saturate.
    pixels = (pixels * 255).round().clip(0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path)
def write_video_to_folder(video, path, channels, zero_padded, ext):
    """Write every frame of a video as a numbered image inside a folder.

    Args:
        video: Sequence of frames indexed along axis 0.
        path: Destination folder (created if needed).
        channels: Channel layout forwarded to ``write_frame``.
        zero_padded: If True, frame numbers are left-padded with zeros to
            the number of digits of the frame count.
        ext: File extension (without the dot) of the written images.
    """
    create_folder(path)
    num_frames = video.shape[0]
    # A width of 0 makes zfill a no-op, i.e. no padding.
    digits = len(str(num_frames)) if zero_padded else 0
    for index in range(num_frames):
        file_name = f"{str(index).zfill(digits)}.{ext}"
        write_frame(video[index], os.path.join(path, file_name), channels)
def write_video(video, path, channels="first", zero_padded=True, ext="png", dtype="torch"):
    """Write a video either as an ``.mp4`` file or as a folder of frames.

    Args:
        video: Frames as a torch tensor, or a numpy array when
            ``dtype == "numpy"``.
        path: Destination; a path ending in ``.mp4`` produces a video file,
            anything else is treated as a folder of images.
        channels: Channel layout of ``video`` (``"first"`` for channels
            before height and width).
        zero_padded: Whether frame file names are zero-padded (folder output
            only).
        ext: Image extension for folder output.
        dtype: ``"numpy"`` if ``video`` is an ndarray that must be converted.
    """
    frames = torch.from_numpy(video) if dtype == "numpy" else video
    if not path.endswith(".mp4"):
        write_video_to_folder(frames, path, channels, zero_padded, ext)
        return
    write_video_to_file(frames, path, channels)