import json
import os
import copy
import math
import random
import collections
import logging
import sys
import traceback

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import albumentations
import soundfile
from torchvision import transforms
from tqdm import tqdm
from random import sample
from glob import glob
from PIL import Image

sys.path.insert(0, '.')  # nopep8
from train import instantiate_from_config
from foleycrafter.models.specvqgan.data.transforms import *

torchaudio.set_audio_backend("sox_io")

logger = logging.getLogger(f'main.{__name__}')

SR = 22050
FPS = 15
MAX_SAMPLE_ITER = 10


def non_negative(x): return int(np.round(max(0, x), 0))


def rms(x): return np.sqrt(np.mean(x**2))


def get_GH_data_identifier(video_name, start_idx, split='_'):
    if isinstance(start_idx, str):
        return video_name + split + start_idx
    elif isinstance(start_idx, int):
        return video_name + split + str(start_idx)
    else:
        raise NotImplementedError
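
# A minimal sketch of the identifier format (the video name below is a made-up
# example, not a real clip):
#   get_GH_data_identifier('2015-02-16-16-49-06', 132300)
#   -> '2015-02-16-16-49-06_132300'
# The start index is in audio samples; the inverse mapping used throughout this
# file is identifier.split('_').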


class Crop(object):
    """Center- or random-crops the 'image' (and optional 'cond_image') entries of an item dict."""

    def __init__(self, cropped_shape=None, random_crop=False):
        self.cropped_shape = cropped_shape
        if cropped_shape is not None:
            mel_num, spec_len = cropped_shape
            if random_crop:
                self.cropper = albumentations.RandomCrop
            else:
                self.cropper = albumentations.CenterCrop
            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __call__(self, item):
        item['image'] = self.preprocessor(image=item['image'])['image']
        if 'cond_image' in item.keys():
            item['cond_image'] = self.preprocessor(image=item['cond_image'])['image']
        return item


class CropImage(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)


class CropFeats(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)

    def __call__(self, item):
        item['feature'] = self.preprocessor(image=item['feature'])['image']
        return item


class CropCoords(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)

    def __call__(self, item):
        item['coord'] = self.preprocessor(image=item['coord'])['image']
        return item
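
# A minimal usage sketch (the 80x860 shape is an assumption based on the mel
# spectrogram comments further down, not a definitive spec):
#   crop = CropImage([80, 848], False)                       # center crop
#   item = crop({'image': np.zeros((80, 860), dtype=np.float32)})
#   item['image'].shape  # -> (80, 848): cropped along the time axis only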


class ResampleFrames(object):
    def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
        self.feat_sample_size = feat_sample_size
        self.times_to_repeat_after_resample = times_to_repeat_after_resample

    def __call__(self, item):
        feat_len = item['feature'].shape[0]

        ## resample
        assert feat_len >= self.feat_sample_size
        # evenly spaced points (abcdefghkl -> aoooofoooo)
        idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)
        # xoooo xoooo -> ooxoo ooxoo
        shift = feat_len // (self.feat_sample_size + 1)
        idx = idx + shift

        ## repeat after resampling (abc -> aaaabbbbcccc)
        if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
            idx = np.repeat(idx, self.times_to_repeat_after_resample)

        item['feature'] = item['feature'][idx, :]
        return item
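
# Worked example of the index arithmetic above (assumed values): with
# feat_len=10 and feat_sample_size=2, linspace gives idx=[0, 5]; the shift is
# 10 // 3 = 3, so the final indices are [3, 8] -- i.e. each sample is taken
# from the middle of its segment rather than its left edge.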


class GreatestHitSpecs(torch.utils.data.Dataset):

    def __init__(self, split, spec_dir_path, spec_len, random_crop, mel_num,
                 spec_crop_len, L=2.0, rand_shift=False, spec_transforms=None, splits_path='./data',
                 meta_path='./data/info_r2plus1d_dim1024_15fps.json'):
        super().__init__()
        self.split = split
        self.specs_dir = spec_dir_path
        self.spec_transforms = spec_transforms
        self.splits_path = splits_path
        self.meta_path = meta_path
        self.spec_len = spec_len
        self.rand_shift = rand_shift
        self.L = L
        self.spec_take_first = int(math.ceil(860 * (L / 10.) / 32) * 32)
        self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
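        # Worked example: a full 10 s clip spans 860 spectrogram frames, so for
        # L = 2.0 s this keeps ceil(860 * 0.2 / 32) * 32 = 6 * 32 = 192 frames,
        # rounded up to a multiple of 32 and capped at 860.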
        greatesthit_meta = json.load(open(self.meta_path, 'r'))
        unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
        self.label2target = {label: target for target, label in enumerate(unique_classes)}
        self.target2label = {target: label for label, target in self.label2target.items()}
        self.video_idx2label = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
        }
        self.available_video_hit = list(self.video_idx2label.keys())
        self.video_idx2path = {
            vh: os.path.join(self.specs_dir,
                             vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') + '_mel.npy')
            for vh in self.available_video_hit
        }
        self.video_idx2idx = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            i for i in range(len(greatesthit_meta['video_name']))
        }

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        self.dataset = clip_video_hit
        spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
        self.spec_transforms = transforms.Compose([
            CropImage([mel_num, spec_crop_len], random_crop),
            # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=0),
            # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=0)
        ])

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}

        video_idx = self.dataset[idx]
        spec_path = self.video_idx2path[video_idx]
        spec = np.load(spec_path)  # (80, 860)

        if self.rand_shift:
            shift = random.uniform(0, 0.5)
            spec_shift = int(shift * spec.shape[1] // 10)
            # Since only the first second is used
            spec = np.roll(spec, -spec_shift, 1)

        # concat spec outside dataload
        item['image'] = 2 * spec - 1  # (80, 860)
        item['image'] = item['image'][:, :self.spec_take_first]
        item['file_path'] = spec_path

        item['label'] = self.video_idx2label[video_idx]
        item['target'] = self.label2target[item['label']]

        if self.spec_transforms is not None:
            item = self.spec_transforms(item)
        return item


class GreatestHitSpecsTrain(GreatestHitSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class GreatestHitSpecsValidation(GreatestHitSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class GreatestHitSpecsTest(GreatestHitSpecs):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)


class GreatestHitWave(torch.utils.data.Dataset):

    def __init__(self, split, wav_dir, random_crop, mel_num, spec_crop_len, spec_len,
                 L=2.0, splits_path='./data', rand_shift=True,
                 data_path='data/greatesthit/greatesthit-process-resized'):
        super().__init__()
        self.split = split
        self.wav_dir = wav_dir
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.rand_shift = rand_shift

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
        self.dataset = clip_video_hit

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video_idx = self.dataset[idx]
        video, start_idx = video_idx.split('_')
        start_idx = int(start_idx)
        if self.rand_shift:
            shift = int(random.uniform(-0.5, 0.5) * SR)
            start_idx = non_negative(start_idx + shift)

        wave_path = self.video_audio_path[video]
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'
        return item


class GreatestHitWaveTrain(GreatestHitWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class GreatestHitWaveValidation(GreatestHitWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class GreatestHitWaveTest(GreatestHitWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)


class CondGreatestHitSpecsCondOnImage(torch.utils.data.Dataset):

    def __init__(self, split, specs_dir, spec_len, feat_len, feat_depth, feat_crop_len, random_crop, mel_num, spec_crop_len,
                 vqgan_L=10.0, L=1.0, rand_shift=False, spec_transforms=None, frame_transforms=None, splits_path='./data',
                 meta_path='./data/info_r2plus1d_dim1024_15fps.json', frame_path='data/greatesthit/greatesthit_processed',
                 p_outside_cond=0., p_audio_aug=0.5):
        super().__init__()
        self.split = split
        self.specs_dir = specs_dir
        self.spec_transforms = spec_transforms
        self.frame_transforms = frame_transforms
        self.splits_path = splits_path
        self.meta_path = meta_path
        self.frame_path = frame_path
        self.feat_len = feat_len
        self.feat_depth = feat_depth
        self.feat_crop_len = feat_crop_len
        self.spec_len = spec_len
        self.rand_shift = rand_shift
        self.L = L
        self.spec_take_first = int(math.ceil(860 * (vqgan_L / 10.) / 32) * 32)
        self.spec_take_first = 860 if self.spec_take_first > 860 else self.spec_take_first
        self.p_outside_cond = torch.tensor(p_outside_cond)

        greatesthit_meta = json.load(open(self.meta_path, 'r'))
        unique_classes = sorted(list(set(ht for ht in greatesthit_meta['hit_type'])))
        self.label2target = {label: target for target, label in enumerate(unique_classes)}
        self.target2label = {target: label for label, target in self.label2target.items()}
        self.video_idx2label = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            greatesthit_meta['hit_type'][i] for i in range(len(greatesthit_meta['video_name']))
        }
        self.available_video_hit = list(self.video_idx2label.keys())
        self.video_idx2path = {
            vh: os.path.join(self.specs_dir,
                             vh.replace('_', '_denoised_') + '_' + self.video_idx2label[vh].replace(' ', '_') + '_mel.npy')
            for vh in self.available_video_hit
        }
        for value in self.video_idx2path.values():
            assert os.path.exists(value)
        self.video_idx2idx = {
            get_GH_data_identifier(greatesthit_meta['video_name'][i], greatesthit_meta['start_idx'][i]):
            i for i in range(len(greatesthit_meta['video_name']))
        }

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            self.make_split_files()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        self.dataset = clip_video_hit
        spec_crop_len = self.spec_take_first if self.spec_take_first <= spec_crop_len else spec_crop_len
        self.spec_transforms = transforms.Compose([
            CropImage([mel_num, spec_crop_len], random_crop),
            # transforms.RandomApply([FrequencyMasking(freq_mask_param=20)], p=p_audio_aug),
            # transforms.RandomApply([TimeMasking(time_mask_param=int(32 * self.L))], p=p_audio_aug)
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(128),
                RandomResizedCrop3D(112, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        clip_classes = [self.label2target[self.video_idx2label[vh]] for vh in clip_video_hit]
        class2count = collections.Counter(clip_classes)
        self.class_counts = torch.tensor([class2count[cls] for cls in range(len(class2count))])

        if self.L != 1.0:
            print(split, L)
            self.validate_data()
        # rebuild the index after validation may have dropped entries
        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)

    def __len__(self):
        return len(self.dataset)
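
    # Each item pairs a target hit with a conditioning hit: with probability
    # p_outside_cond the condition is drawn from a different video; otherwise it
    # is another hit from the same video (single-hit videos were dropped in
    # __init__, so a second hit always exists).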
    def __getitem__(self, idx):
        item = {}
        try:
            video_idx = self.dataset[idx]
            spec_path = self.video_idx2path[video_idx]
            spec = np.load(spec_path)  # (80, 860)

            video, start_idx = video_idx.split('_')
            frame_path = os.path.join(self.frame_path, video, 'frames')
            start_frame_idx = non_negative(FPS * int(start_idx) / SR)
            end_frame_idx = non_negative(start_frame_idx + FPS * self.L)

            if self.rand_shift:
                shift = random.uniform(0, 0.5)
                spec_shift = int(shift * spec.shape[1] // 10)
                # Since only the first second is used
                spec = np.roll(spec, -spec_shift, 1)
                start_frame_idx += int(FPS * shift)
                end_frame_idx += int(FPS * shift)

            frames = [Image.open(os.path.join(
                frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
                range(start_frame_idx, end_frame_idx)]

            # Sample condition
            if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
                # Sample condition from outside video
                all_idx = set(list(range(len(self.dataset))))
                all_idx.remove(idx)
                # random.sample requires a sequence, not a set, since Python 3.11
                cond_video_idx = self.dataset[sample(sorted(all_idx), k=1)[0]]
                cond_video, cond_start_idx = cond_video_idx.split('_')
            else:
                cond_video = video
                video_hits_idx = copy.copy(self.video2indexes[video])
                video_hits_idx.remove(start_idx)
                cond_start_idx = sample(video_hits_idx, k=1)[0]
                cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)

            cond_spec_path = self.video_idx2path[cond_video_idx]
            cond_spec = np.load(cond_spec_path)  # (80, 860)
            cond_video, cond_start_idx = cond_video_idx.split('_')
            cond_frame_path = os.path.join(self.frame_path, cond_video, 'frames')
            cond_start_frame_idx = non_negative(FPS * int(cond_start_idx) / SR)
            cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)

            if self.rand_shift:
                cond_shift = random.uniform(0, 0.5)
                cond_spec_shift = int(cond_shift * cond_spec.shape[1] // 10)
                # Since only the first second is used
                cond_spec = np.roll(cond_spec, -cond_spec_shift, 1)
                cond_start_frame_idx += int(FPS * cond_shift)
                cond_end_frame_idx += int(FPS * cond_shift)

            cond_frames = [Image.open(os.path.join(
                cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
                range(cond_start_frame_idx, cond_end_frame_idx)]

            # concat spec outside dataload
            item['image'] = 2 * spec - 1  # (80, 860)
            item['cond_image'] = 2 * cond_spec - 1  # (80, 860)
            item['image'] = item['image'][:, :self.spec_take_first]
            item['cond_image'] = item['cond_image'][:, :self.spec_take_first]
            item['file_path_specs_'] = spec_path
            item['file_path_cond_specs_'] = cond_spec_path

            if self.frame_transforms is not None:
                cond_frames = self.frame_transforms(cond_frames)
                frames = self.frame_transforms(frames)

            item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
            item['file_path_feats_'] = (frame_path, start_frame_idx)
            item['file_path_cond_feats_'] = (cond_frame_path, cond_start_frame_idx)

            item['label'] = self.video_idx2label[video_idx]
            item['target'] = self.label2target[item['label']]

            if self.spec_transforms is not None:
                item = self.spec_transforms(item)
        except Exception:
            traceback.print_exc()
            print('!!!!!!!!!!!!!!!!!!!!', video_idx, cond_video_idx)
            print('!!!!!!!!!!!!!!!!!!!!', end_frame_idx, cond_end_frame_idx)
            exit(1)

        return item

    def validate_data(self):
        original_len = len(self.dataset)
        valid_dataset = []
        for video_idx in tqdm(self.dataset):
            video, start_idx = video_idx.split('_')
            frame_path = os.path.join(self.frame_path, video, 'frames')
            start_frame_idx = non_negative(FPS * int(start_idx) / SR)
            end_frame_idx = non_negative(start_frame_idx + FPS * (self.L + 0.6))
            if os.path.exists(os.path.join(frame_path, f'frame{end_frame_idx:0>6d}.jpg')):
                valid_dataset.append(video_idx)
            else:
                self.video2indexes[video].remove(start_idx)
        # iterate over a copy: removing from the list being iterated would skip elements
        for video_idx in list(valid_dataset):
            video, start_idx = video_idx.split('_')
            if len(self.video2indexes[video]) == 1:
                valid_dataset.remove(video_idx)
        if original_len != len(valid_dataset):
            print(f'Validated dataset with enough frames: {len(valid_dataset)}')
        self.dataset = valid_dataset
        split_clip_ids_path = os.path.join(self.splits_path, f'greatesthit_{self.split}_{self.L:.2f}.json')
        if not os.path.exists(split_clip_ids_path):
            with open(split_clip_ids_path, 'w') as f:
                json.dump(valid_dataset, f)

    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
        random.seed(1337)
        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        # The downloaded videos (some went missing on YouTube and are no longer available)
        available_mel_paths = set(glob(os.path.join(self.specs_dir, '*_mel.npy')))
        self.available_video_hit = [vh for vh in self.available_video_hit if self.video_idx2path[vh] in available_mel_paths]

        all_video = list(self.video2indexes.keys())
        print(f'The number of clips available after download: {len(self.available_video_hit)}')
        print(f'The number of videos available after download: {len(all_video)}')

        available_idx = list(range(len(all_video)))
        random.shuffle(available_idx)
        assert sum(ratio) == 1.
        cut_train = int(ratio[0] * len(all_video))
        cut_test = cut_train + int(ratio[1] * len(all_video))

        train_idx = available_idx[:cut_train]
        test_idx = available_idx[cut_train:cut_test]
        valid_idx = available_idx[cut_test:]

        train_video = [all_video[i] for i in train_idx]
        test_video = [all_video[i] for i in test_idx]
        valid_video = [all_video[i] for i in valid_idx]

        train_video_hit = []
        for v in train_video:
            train_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
        test_video_hit = []
        for v in test_video:
            test_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]
        valid_video_hit = []
        for v in valid_video:
            valid_video_hit += [get_GH_data_identifier(v, hit_idx) for hit_idx in self.video2indexes[v]]

        # mix train and valid for a better validation loss
        mixed = train_video_hit + valid_video_hit
        random.shuffle(mixed)
        split = int(len(mixed) * ratio[0] / (ratio[0] + ratio[2]))
        train_video_hit = mixed[:split]
        valid_video_hit = mixed[split:]

        with open(os.path.join(self.splits_path, 'greatesthit_train.json'), 'w') as train_file, \
                open(os.path.join(self.splits_path, 'greatesthit_test.json'), 'w') as test_file, \
                open(os.path.join(self.splits_path, 'greatesthit_valid.json'), 'w') as valid_file:
            json.dump(train_video_hit, train_file)
            json.dump(test_video_hit, test_file)
            json.dump(valid_video_hit, valid_file)

        print(f'Put {len(train_idx)} clips to the train set and saved it to ./data/greatesthit_train.json')
        print(f'Put {len(test_idx)} clips to the test set and saved it to ./data/greatesthit_test.json')
        print(f'Put {len(valid_idx)} clips to the valid set and saved it to ./data/greatesthit_valid.json')


class CondGreatestHitSpecsCondOnImageTrain(CondGreatestHitSpecsCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(256),
            RandomResizedCrop3D(224, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.1, saturation=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class CondGreatestHitSpecsCondOnImageValidation(CondGreatestHitSpecsCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(256),
            CenterCrop3D(224),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class CondGreatestHitSpecsCondOnImageTest(CondGreatestHitSpecsCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(256),
            CenterCrop3D(224),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)


class CondGreatestHitWaveCondOnImage(torch.utils.data.Dataset):

    def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
                 L=2.0, frame_transforms=None, splits_path='./data',
                 data_path='data/greatesthit/greatesthit-process-resized',
                 p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
        super().__init__()
        self.split = split
        self.wav_dir = wav_dir
        self.frame_transforms = frame_transforms
        self.splits_path = splits_path
        self.data_path = data_path
        self.spec_len = spec_len
        self.L = L
        self.rand_shift = rand_shift
        self.p_outside_cond = torch.tensor(p_outside_cond)

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
        self.dataset = clip_video_hit

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(256),
                RandomResizedCrop3D(224, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video_idx = self.dataset[idx]
        video, start_idx = video_idx.split('_')
        start_idx = int(start_idx)
        frame_path = os.path.join(self.data_path, video, 'frames')
        start_frame_idx = non_negative(FPS * int(start_idx) / SR)
        if self.rand_shift:
            shift = random.uniform(-0.5, 0.5)
            start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
            start_idx = non_negative(start_idx + int(SR * shift))
        if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
            start_frame_idx = self.video_frame_cnt[video] - self.left_over
            start_idx = non_negative(SR * (start_frame_idx / FPS))
        end_frame_idx = non_negative(start_frame_idx + FPS * self.L)

        # target
        wave_path = self.video_audio_path[video]
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(start_frame_idx, end_frame_idx)]
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        # cond
        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
            all_idx = set(list(range(len(self.dataset))))
            all_idx.remove(idx)
            # random.sample requires a sequence, not a set, since Python 3.11
            cond_video_idx = self.dataset[sample(sorted(all_idx), k=1)[0]]
            cond_video, cond_start_idx = cond_video_idx.split('_')
        else:
            cond_video = video
            video_hits_idx = copy.copy(self.video2indexes[video])
            if str(start_idx) in video_hits_idx:
                video_hits_idx.remove(str(start_idx))
            cond_start_idx = sample(video_hits_idx, k=1)[0]
            cond_video_idx = get_GH_data_identifier(cond_video, cond_start_idx)

        cond_video, cond_start_idx = cond_video_idx.split('_')
        cond_start_idx = int(cond_start_idx)
        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
        cond_start_frame_idx = non_negative(FPS * int(cond_start_idx) / SR)
        cond_wave_path = self.video_audio_path[cond_video]

        if self.rand_shift:
            cond_shift = random.uniform(-0.5, 0.5)
            cond_start_frame_idx = non_negative(cond_start_frame_idx + int(FPS * cond_shift))
            # shift the condition audio by the same amount as its frames
            cond_start_idx = non_negative(cond_start_idx + int(cond_shift * SR))
        if cond_start_frame_idx > self.video_frame_cnt[cond_video] - self.left_over:
            cond_start_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
            cond_start_idx = non_negative(SR * (cond_start_frame_idx / FPS))
        cond_end_frame_idx = non_negative(cond_start_frame_idx + FPS * self.L)

        cond_frames = [Image.open(os.path.join(
            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(cond_start_frame_idx, cond_end_frame_idx)]
        cond_wav, _ = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_start_idx)
        cond_wav = self.wav_transforms(cond_wav)

        item['image'] = wav  # (44100,)
        item['cond_image'] = cond_wav  # (44100,)
        item['file_path_wav_'] = wave_path
        item['file_path_cond_wav_'] = cond_wave_path

        if self.frame_transforms is not None:
            cond_frames = self.frame_transforms(cond_frames)
            frames = self.frame_transforms(frames)

        item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)
        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def validate_data(self):
        raise NotImplementedError()

    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
        random.seed(1337)
        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        all_video = sorted(os.listdir(self.data_path))
        print(f'The number of videos available after download: {len(all_video)}')

        available_idx = list(range(len(all_video)))
        random.shuffle(available_idx)
        assert sum(ratio) == 1.
        cut_train = int(ratio[0] * len(all_video))
        cut_test = cut_train + int(ratio[1] * len(all_video))

        train_idx = available_idx[:cut_train]
        test_idx = available_idx[cut_train:cut_test]
        valid_idx = available_idx[cut_test:]

        train_video = [all_video[i] for i in train_idx]
        test_video = [all_video[i] for i in test_idx]
        valid_video = [all_video[i] for i in valid_idx]

        with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file, \
                open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file, \
                open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
            json.dump(train_video, train_file)
            json.dump(test_video, test_file)
            json.dump(valid_video, valid_file)

        print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
        print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
        print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')


class CondGreatestHitWaveCondOnImageTrain(CondGreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class CondGreatestHitWaveCondOnImageValidation(CondGreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class CondGreatestHitWaveCondOnImageTest(CondGreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)


class GreatestHitWaveCondOnImage(torch.utils.data.Dataset):

    def __init__(self, split, wav_dir, spec_len, random_crop, mel_num, spec_crop_len,
                 L=2.0, frame_transforms=None, splits_path='./data',
                 data_path='data/greatesthit/greatesthit-process-resized',
                 p_outside_cond=0., p_audio_aug=0.5, rand_shift=True):
        super().__init__()
        self.split = split
        self.wav_dir = wav_dir
        self.frame_transforms = frame_transforms
        self.splits_path = splits_path
        self.data_path = data_path
        self.spec_len = spec_len
        self.L = L
        self.rand_shift = rand_shift
        self.p_outside_cond = torch.tensor(p_outside_cond)

        split_clip_ids_path = os.path.join(splits_path, f'greatesthit_{split}.json')
        if not os.path.exists(split_clip_ids_path):
            raise NotImplementedError()
        clip_video_hit = json.load(open(split_clip_ids_path, 'r'))

        video_name = list(set([vidx.split('_')[0] for vidx in clip_video_hit]))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) // 2 for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_denoised_resampled.wav') for v in video_name}
        self.dataset = clip_video_hit

        self.video2indexes = {}
        for video_idx in self.dataset:
            video, start_idx = video_idx.split('_')
            if video not in self.video2indexes.keys():
                self.video2indexes[video] = []
            self.video2indexes[video].append(start_idx)
        for video in self.video2indexes.keys():
            if len(self.video2indexes[video]) == 1:  # given video contains only one hit
                self.dataset.remove(
                    get_GH_data_identifier(video, self.video2indexes[video][0])
                )

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(256),
                RandomResizedCrop3D(224, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video_idx = self.dataset[idx]
        video, start_idx = video_idx.split('_')
        start_idx = int(start_idx)
        frame_path = os.path.join(self.data_path, video, 'frames')
        start_frame_idx = non_negative(FPS * int(start_idx) / SR)
        if self.rand_shift:
            shift = random.uniform(-0.5, 0.5)
            start_frame_idx = non_negative(start_frame_idx + int(FPS * shift))
            start_idx = non_negative(start_idx + int(SR * shift))
        if start_frame_idx > self.video_frame_cnt[video] - self.left_over:
            start_frame_idx = self.video_frame_cnt[video] - self.left_over
            start_idx = non_negative(SR * (start_frame_idx / FPS))
        end_frame_idx = non_negative(start_frame_idx + FPS * self.L)

        # target
        wave_path = self.video_audio_path[video]
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(start_frame_idx, end_frame_idx)]
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,)
        item['file_path_wav_'] = wave_path

        if self.frame_transforms is not None:
            frames = self.frame_transforms(frames)

        item['feature'] = torch.stack(frames, dim=0)  # (15 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def validate_data(self):
        raise NotImplementedError()

    def make_split_files(self, ratio=[0.85, 0.1, 0.05]):
        random.seed(1337)
        print(f'The split files do not exist @ {self.splits_path}. Calculating the new ones.')

        all_video = sorted(os.listdir(self.data_path))
        print(f'The number of videos available after download: {len(all_video)}')

        available_idx = list(range(len(all_video)))
        random.shuffle(available_idx)
        assert sum(ratio) == 1.
        cut_train = int(ratio[0] * len(all_video))
        cut_test = cut_train + int(ratio[1] * len(all_video))

        train_idx = available_idx[:cut_train]
        test_idx = available_idx[cut_train:cut_test]
        valid_idx = available_idx[cut_test:]

        train_video = [all_video[i] for i in train_idx]
        test_video = [all_video[i] for i in test_idx]
        valid_video = [all_video[i] for i in valid_idx]

        with open(os.path.join(self.splits_path, 'greatesthit_video_train.json'), 'w') as train_file, \
                open(os.path.join(self.splits_path, 'greatesthit_video_test.json'), 'w') as test_file, \
                open(os.path.join(self.splits_path, 'greatesthit_video_valid.json'), 'w') as valid_file:
            json.dump(train_video, train_file)
            json.dump(test_video, test_file)
            json.dump(valid_video, valid_file)

        print(f'Put {len(train_idx)} videos to the train set and saved it to ./data/greatesthit_video_train.json')
        print(f'Put {len(test_idx)} videos to the test set and saved it to ./data/greatesthit_video_test.json')
        print(f'Put {len(valid_idx)} videos to the valid set and saved it to ./data/greatesthit_video_valid.json')


class GreatestHitWaveCondOnImageTrain(GreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class GreatestHitWaveCondOnImageValidation(GreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class GreatestHitWaveCondOnImageTest(GreatestHitWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)


def draw_spec(spec, dest, cmap='magma'):
    plt.imshow(spec, cmap=cmap, origin='lower')
    plt.axis('off')
    plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
    plt.close()
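
# A minimal usage sketch (the .npy path below is a placeholder, not a real file):
#   spec = np.load('some_clip_mel.npy')      # (80, 860) mel spectrogram
#   draw_spec(spec[:, :192], 'spec.png')     # render e.g. the first 2 s slice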


if __name__ == '__main__':
    from omegaconf import OmegaConf

    # cfg = OmegaConf.load('configs/greatesthit_transformer_with_vNet_randshift_2s_GH_vqgan_no_earlystop.yaml')
    cfg = OmegaConf.load('configs/greatesthit_codebook.yaml')
    data = instantiate_from_config(cfg.data)
    data.prepare_data()
    data.setup()

    print(len(data.datasets['train']))
    print(data.datasets['train'][24])