import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
import torch.distributed as dist
import torchaudio
import torchvision
import torchvision.io
import os, io, csv, math, random
import os.path as osp
from pathlib import Path
import numpy as np
import pandas as pd
from einops import rearrange
import glob
from decord import VideoReader, AudioReader
import decord
from copy import deepcopy
import pickle
from petrel_client.client import Client
import sys
sys.path.append('./')
from foleycrafter.data import video_transforms
from foleycrafter.utils.util import (
    random_audio_video_clip, get_full_indices, video_tensor_to_np, get_video_frames)
from foleycrafter.utils.spec_to_mel import wav_tensor_to_fbank, read_wav_file_io, load_audio, normalize_wav, pad_wav
from foleycrafter.utils.converter import get_mel_spectrogram_from_audio, pad_spec, normalize, normalize_spectrogram

def zero_rank_print(s):
    # print only from rank 0 (or when torch.distributed is not initialized)
    if not dist.is_initialized() or dist.get_rank() == 0:
        print("### " + s, flush=True)

def get_mel(audio_data, audio_cfg):
    # mel shape: (n_mels, T)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg["sample_rate"],
        n_fft=audio_cfg["window_size"],
        win_length=audio_cfg["window_size"],
        hop_length=audio_cfg["hop_size"],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=64,
        f_min=audio_cfg["fmin"],
        f_max=audio_cfg["fmax"],
    ).to(audio_data.device)
    mel = mel(audio_data)
    # we use the log-mel spectrogram as input
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel  # (n_mels, T)
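
def _demo_get_mel():
    # Usage sketch, not part of the original pipeline: the audio_cfg values
    # below are illustrative assumptions, not this repo's training config.
    audio_cfg = {"sample_rate": 16000, "window_size": 1024,
                 "hop_size": 160, "fmin": 0, "fmax": 8000}
    waveform = torch.randn(16000)           # 1 s of mono audio at 16 kHz
    log_mel = get_mel(waveform, audio_cfg)  # -> (64, T) log-mel spectrogram
    return log_mel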

def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    clip_val: lower bound that keeps normalize_fun away from zero
    normalize_fun: compression function applied to the clamped input
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)
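
def _demo_dynamic_range_compression():
    # Usage sketch, not part of the original pipeline: log-compress a fake
    # non-negative power spectrogram.
    spec = torch.rand(64, 100)                  # stand-in mel power spectrogram
    log_spec = dynamic_range_compression(spec)  # log(clamp(spec, min=1e-5))
    return log_spec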

class CPU_Unpickler(pickle.Unpickler):
    # unpickler that remaps torch storages saved from GPU onto the CPU
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)
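
def _demo_cpu_unpickler():
    # Usage sketch, not part of the original pipeline: load a pickle that may
    # contain CUDA tensors on a CPU-only machine. 'features.pkl' is a
    # hypothetical path, not a file shipped with this repo.
    with open('features.pkl', 'rb') as f:
        obj = CPU_Unpickler(f).load()  # all torch tensors arrive on CPU
    return obj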

class AudioSetStrong(Dataset):
    # reads pre-extracted features (mel, text embeds, event labels) plus the paired videos
    def __init__(
        self,
    ):
        super().__init__()
        # petrel-oss client used to read features and videos from remote
        # storage; assumes a configured petreloss.conf
        self._client = Client()
        self.data_path = 'data/AudioSetStrong/train/feature'
        # object names of the pickled feature files under data_path
        self.data_list = list(self._client.list(self.data_path))
        self.length = len(self.data_list)
        # get video feature
        self.video_path = 'data/AudioSetStrong/train/video'
        vision_transform_list = [
            transforms.Resize((128, 128)),
            transforms.CenterCrop((112, 112)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.video_transform = transforms.Compose(vision_transform_list)

    def get_batch(self, idx):
        # data_list holds object names: fetch the pickled feature file and
        # unpickle it on CPU
        feature_bytes = self._client.Get(osp.join(self.data_path, self.data_list[idx]))
        embeds = CPU_Unpickler(io.BytesIO(feature_bytes)).load()
        mel = embeds['mel']
        save_bsz = mel.shape[0]
        audio_info = embeds['audio_info']
        text_embeds = embeds['text_embeds']
        # join each clip's active event labels into a comma-separated prompt
        audio_info_array = np.array(audio_info['label_list'])
        prompts = []
        for i in range(save_bsz):
            prompts.append(', '.join(audio_info_array[i, :audio_info['event_num'][i]].tolist()))
        # read the paired videos and stack them into a single tensor
        video_list = []
        for video_name in audio_info['audio_name']:
            video_bytes = self._client.Get(osp.join(self.video_path, video_name + '.mp4'))
            video_reader = VideoReader(io.BytesIO(video_bytes))
            video = video_reader.get_batch(get_full_indices(video_reader)).asnumpy()
            video = get_video_frames(video, 150)
            video = torch.from_numpy(video).permute(0, 3, 1, 2).contiguous().float()
            video = self.video_transform(video)
            video_list.append(video)
        assert len(video_list) > 0, 'no video read'
        videos = torch.stack(video_list, dim=0)
        return mel, audio_info, text_embeds, prompts, videos

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # retry with a random index if a sample fails to load
        while True:
            try:
                mel, audio_info, text_embeds, prompts, videos = self.get_batch(idx)
                break
            except Exception as e:
                zero_rank_print(f' >>> load error: {e} <<<')
                idx = random.randint(0, self.length - 1)
        sample = dict(mel=mel, audio_info=audio_info, text_embeds=text_embeds, prompts=prompts, videos=videos)
        return sample
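
def _demo_audioset_strong():
    # Usage sketch, not part of the original pipeline; assumes the petrel-oss
    # paths above are reachable. Each feature file already holds a batch of
    # clips, so a single sample is a full batch.
    dataset = AudioSetStrong()
    sample = dataset[0]  # keys: mel, audio_info, text_embeds, prompts, videos
    return sample['mel'], sample['videos'], sample['prompts']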

class VGGSound(Dataset):
    # reads pre-extracted audio (mel) and visual features
    def __init__(
        self,
    ):
        super().__init__()
        self.data_path = 'data/VGGSound/train/video'
        self.visual_data_path = 'data/VGGSound/train/feature'
        # sort both lists so the audio and visual feature files stay
        # index-aligned (glob order is filesystem-dependent)
        self.embeds_list = sorted(glob.glob(f'{self.data_path}/*.pt'))
        self.visual_list = sorted(glob.glob(f'{self.visual_data_path}/*.pt'))
        self.length = len(self.embeds_list)

    def get_batch(self, idx):
        embeds = torch.load(self.embeds_list[idx], map_location='cpu')
        visual_embeds = torch.load(self.visual_list[idx], map_location='cpu')
        visual_embeds = visual_embeds['visual_embeds']
        text = embeds['text']
        audio = embeds['mel']  # the mel spectrogram serves as the audio target
        return visual_embeds, audio, text

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # retry with a random index if a sample fails to load
        while True:
            try:
                visual_embeds, audio, text = self.get_batch(idx)
                break
            except Exception as e:
                zero_rank_print(f'load error: {e}')
                idx = random.randint(0, self.length - 1)
        sample = dict(visual_embeds=visual_embeds, audio=audio, text=text)
        return sample
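
def _demo_vggsound():
    # Usage sketch, not part of the original pipeline; assumes the .pt feature
    # files exist under the paths above.
    from torch.utils.data import DataLoader
    dataset = VGGSound()
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    return batch['visual_embeds'], batch['audio'], batch['text']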