# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass

@dataclass
class ImageAnnotation:
    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]

@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    trajectories: Optional[str] = None

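# Note: although `trajectories` is annotated as Optional[str] above, the entries
# actually loaded from the .jgz annotation file behave like a mapping with at
# least a "path" key pointing to a per-frame, torch.load-able trajectory file
# (see __getitem__ below). An illustrative annotation entry, with placeholder
# values that are assumptions rather than real dataset contents, could look like:
#   {
#     "sequence_name": "seq_000",
#     "frame_number": 0,
#     "frame_timestamp": 0.0,
#     "image": {"path": "seq_000/images/left/0000.jpg", "size": [720, 1280]},
#     "camera_name": "left",
#     "trajectories": {"path": "seq_000/trajectories/0000.pth"}
#   }
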
class DynamicReplicaDataset(data.Dataset):
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # simple center crop
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        # shift trajectories into the cropped coordinate frame
        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())
            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape

        # randomly subsample trajectories down to traj_per_sample
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        # mark points that fall outside the image bounds as invisible
        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # keep only points that are visible in more than 10 frames
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            valid=torch.ones(T, N),
            seq_name=sample[0].sequence_name,
        )
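

# Minimal usage sketch (not part of the original module). It assumes the
# Dynamic Replica data has been extracted to a hypothetical local root that
# contains the "valid" split with its frame_annotations_valid.jgz file;
# adjust root, split, and crop_size to your setup.
if __name__ == "__main__":
    dataset = DynamicReplicaDataset(
        root="/path/to/dynamic_replica",  # hypothetical dataset root
        split="valid",
        traj_per_sample=256,
        crop_size=(384, 512),  # (H, W); pass None to keep full-resolution frames
        sample_len=-1,  # -1 keeps each sequence as a single sample
    )
    print(f"{len(dataset)} samples")

    sample = dataset[0]
    # video: (T, 3, H, W) float tensor; trajectory: (T, N, 2); visibility: (T, N)
    print(sample.seq_name, sample.video.shape, sample.trajectory.shape, sample.visibility.shape)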