import os
import math
import os.path as osp
import random
import pickle
import warnings
import glob

import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn.functional as F
import torch.distributed as dist
from torchvision.datasets.video_utils import VideoClips

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']
def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
                   batch_size=16, num_workers=8):
    data = VideoData(data_path, image_folder, resolution, sequence_length, sample_every_n_frames,
                     batch_size, num_workers)
    loader = data._dataloader()
    return loader
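
# Usage sketch: './data/videos' is a hypothetical path; with image_folder=False
# it should contain video files (e.g. .mp4) anywhere below it.
#
#   loader = get_dataloader('./data/videos', image_folder=False,
#                           resolution=128, sequence_length=16, batch_size=4)
#   batch = next(iter(loader))
#   print(batch['video'].shape)  # torch.Size([4, 3, 16, 128, 128]), BCTHW in [0, 1]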
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_parent_dir(path):
    return osp.basename(osp.dirname(path))
def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
    # video: THWC, {0, ..., 255}
    assert in_channels == 3
    video = video.permute(0, 3, 1, 2).float() / 255.  # TCHW
    t, c, h, w = video.shape

    # temporal crop
    if sequence_length is not None:
        assert sequence_length <= t
        video = video[:sequence_length]

    # skip frames
    if sample_every_n_frames > 1:
        video = video[::sample_every_n_frames]

    # scale shorter side to resolution
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    video = F.interpolate(video, size=target_size, mode='bilinear',
                          align_corners=False, antialias=True)

    # center crop
    t, c, h, w = video.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
    video = video.permute(1, 0, 2, 3).contiguous()  # CTHW
    return {'video': video}
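
# A sketch of the preprocessing pipeline on dummy data (not a real clip): a
# 24-frame 240x320 THWC uint8 tensor is scaled to [0, 1], truncated to 16
# frames, resized so the shorter side is 128, and center-cropped to 128x128.
#
#   clip = torch.randint(0, 256, (24, 240, 320, 3), dtype=torch.uint8)
#   out = preprocess(clip, resolution=128, sequence_length=16)
#   print(out['video'].shape)  # torch.Size([3, 16, 128, 128]), CTHW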
def preprocess_image(image):
    # convert a float numpy array in [0, 1] to a torch tensor; the value range is unchanged
    img = torch.from_numpy(image)
    return img
class VideoData(data.Dataset):
    """ Class to create dataloaders for video datasets

    Args:
        data_path: Path to the folder with video frames or videos.
        image_folder: If True, the data is stored as images in folders.
        resolution: Resolution of the returned videos.
        sequence_length: Length of extracted video sequences.
        sample_every_n_frames: Sample every n frames from the video.
        batch_size: Batch size.
        num_workers: Number of workers for the dataloader.
        shuffle: If True, shuffle the data.
    """

    def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
                 sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
        super().__init__()
        self.data_path = data_path
        self.image_folder = image_folder
        self.resolution = resolution
        self.sequence_length = sequence_length
        self.sample_every_n_frames = sample_every_n_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
    def _dataset(self):
        '''
        Initializes and returns the dataset.
        '''
        Dataset = FrameDataset if self.image_folder else VideoDataset
        dataset = Dataset(self.data_path, self.sequence_length,
                          resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
        return dataset
    def _dataloader(self):
        '''
        Initializes and returns the dataloader.
        '''
        dataset = self._dataset()
        if dist.is_initialized():
            sampler = data.distributed.DistributedSampler(
                dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
            )
        else:
            sampler = None
        dataloader = data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=sampler is None and self.shuffle
        )
        return dataloader
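
# A minimal sketch of driving VideoData under DDP (assumes torch.distributed is
# already initialized by the launcher; path and epoch count are hypothetical):
#
#   data_module = VideoData('./data/videos', image_folder=False, resolution=128,
#                           sequence_length=16, sample_every_n_frames=1,
#                           batch_size=4, num_workers=8)
#   loader = data_module._dataloader()
#   for epoch in range(10):
#       if isinstance(loader.sampler, data.distributed.DistributedSampler):
#           loader.sampler.set_epoch(epoch)  # reshuffle the per-rank shards each epoch
#       for batch in loader:
#           ...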
class VideoDataset(data.Dataset):
    """
    Generic dataset for video files stored in folders.
    Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
    The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
    Returns CTHW videos in the range [0, 1] (batched to BCTHW by the dataloader).

    Args:
        data_folder: Path to the folder with corresponding videos stored.
        sequence_length: Length of extracted video sequences.
        resolution: Resolution of the returned videos.
        sample_every_n_frames: Sample every n frames from the video.
    """
    def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
        super().__init__()
        self.sequence_length = sequence_length
        self.resolution = resolution
        self.sample_every_n_frames = sample_every_n_frames

        folder = data_folder
        files = sum([glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
                     for ext in VID_EXTENSIONS], [])

        # VideoClips emits noisy decoder warnings; silence them
        warnings.filterwarnings('ignore')
        # cache the (expensive) VideoClips metadata next to the data
        cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
        if not osp.exists(cache_file):
            clips = VideoClips(files, sequence_length, num_workers=4)
            try:
                with open(cache_file, 'wb') as f:
                    pickle.dump(clips.metadata, f)
            except Exception:
                print(f"Failed to save metadata to {cache_file}")
        else:
            with open(cache_file, 'rb') as f:
                metadata = pickle.load(f)
            clips = VideoClips(files, sequence_length,
                               _precomputed_metadata=metadata)
        self._clips = clips
        # instead of uniformly sampling from all possible clips, we sample uniformly
        # from all possible videos: VideoClips.get_clip calls get_clip_location
        # internally, so overriding it swaps the clip-selection strategy
        self._clips.get_clip_location = self.get_random_clip_from_video
    def get_random_clip_from_video(self, idx: int) -> tuple:
        '''
        Sample a random clip starting index from the video.

        Args:
            idx: Index of the video.
        '''
        # Note that some videos may not contain enough frames; we skip those videos here.
        while self._clips.clips[idx].shape[0] <= 0:
            idx += 1
        n_clip = self._clips.clips[idx].shape[0]
        clip_id = random.randint(0, n_clip - 1)
        return idx, clip_id
    def __len__(self):
        return self._clips.num_videos()

    def __getitem__(self, idx):
        resolution = self.resolution
        while True:
            try:
                video, _, _, idx = self._clips.get_clip(idx)
            except Exception as e:
                print(idx, e)
                # idx indexes videos here (see __len__), so wrap by num_videos
                idx = (idx + 1) % self._clips.num_videos()
                continue
            break
        return preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames)
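
# Standalone usage sketch (hypothetical path). Because clip selection is
# randomized per video, indexing the same idx twice may return different clips:
#
#   ds = VideoDataset('./data/videos', sequence_length=16, resolution=128)
#   sample = ds[0]
#   print(sample['video'].shape)  # torch.Size([3, 16, 128, 128])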
class FrameDataset(data.Dataset):
    """
    Generic dataset for videos stored as images. Loading iterates over all the folders and subfolders
    in the provided directory. Each leaf folder is assumed to contain frames from a single video.

    Args:
        data_folder: path to the folder with video frames. The folder
            should contain folders with frames from each video.
        sequence_length: length of extracted video sequences
        resolution: resolution of the returned videos
        sample_every_n_frames: sample every n frames from the video
    """
    def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
        self.resolution = resolution
        self.sequence_length = sequence_length
        self.sample_every_n_frames = sample_every_n_frames
        self.data_all = self.load_video_frames(data_folder)
        self.video_num = len(self.data_all)

    def __getitem__(self, index):
        batch_data = self.getTensor(index)
        return_list = {'video': batch_data}
        return return_list
    def load_video_frames(self, dataroot: str) -> list:
        '''
        Loads all the video frames under the dataroot and returns a list of all the video frames.

        Args:
            dataroot: The root directory containing the video frames.

        Returns:
            A list of all the video frames.
        '''
        data_all = []
        frame_list = os.walk(dataroot)
        for _, meta in enumerate(frame_list):
            root = meta[0]
            try:
                frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # file names without a trailing integer index cannot be sorted; skip this folder
                print(meta[0], meta[2])
                continue
            if len(frames) < max(0, self.sequence_length * self.sample_every_n_frames):
                continue
            frames = [
                os.path.join(root, item) for item in frames
                if is_image_file(item)
            ]
            # keep videos with at least sequence_length * sample_every_n_frames frames
            # (>= rather than the original strict >, so exact-length videos are not dropped)
            if len(frames) >= max(1, self.sequence_length * self.sample_every_n_frames):
                data_all.append(frames)
        return data_all
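
    # Expected on-disk layout (an illustration inferred from the sorting key,
    # which parses the integer after the last '_' in each file name):
    #
    #   dataroot/
    #       video_000/
    #           frame_0.png
    #           frame_1.png
    #           ...
    #       video_001/
    #           frame_0.png
    #           ...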
    def getTensor(self, index: int) -> torch.Tensor:
        '''
        Returns a tensor of the video frames at the given index.

        Args:
            index: The index of the video frames to return.

        Returns:
            A CTHW tensor in the range `[0, 1]` of the video frames at the given index.
        '''
        video = self.data_all[index]
        video_len = len(video)

        # load the entire video when sequence_length == -1, in which case sample_every_n_frames has to be 1
        if self.sequence_length == -1:
            assert self.sample_every_n_frames == 1
            start_idx = 0
            end_idx = video_len
        else:
            n_frames_interval = self.sequence_length * self.sample_every_n_frames
            start_idx = random.randint(0, video_len - n_frames_interval)
            end_idx = start_idx + n_frames_interval
        # center-crop to a square using the first frame's dimensions
        img = Image.open(video[0])
        h, w = img.height, img.width
        if h > w:
            half = (h - w) // 2
            cropsize = (0, half, w, half + w)  # left, upper, right, lower
        elif w > h:
            half = (w - h) // 2
            cropsize = (half, 0, half + h, h)

        images = []
        for i in range(start_idx, end_idx, self.sample_every_n_frames):
            path = video[i]
            img = Image.open(path)
            if h != w:
                img = img.crop(cropsize)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
            img = img.resize(
                (self.resolution, self.resolution),
                Image.LANCZOS)
            img = np.asarray(img, dtype=np.float32)
            img /= 255.
            img_tensor = preprocess_image(img).unsqueeze(0)
            images.append(img_tensor)

        # (T, H, W, C) -> (C, T, H, W)
        video_clip = torch.cat(images).permute(3, 0, 1, 2)
        return video_clip
    def __len__(self):
        return self.video_num
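
# Usage sketch for FrameDataset (hypothetical path, laid out as illustrated above):
#
#   ds = FrameDataset('./data/frames', sequence_length=16, resolution=64)
#   sample = ds[0]
#   print(sample['video'].shape)  # torch.Size([3, 16, 64, 64])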