Spaces:
Runtime error
Runtime error
| import functools | |
| import io | |
| import json | |
| import math | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" # disable the tokenizer parallelism warning | |
| import random | |
| import re | |
| import string | |
| import subprocess | |
| import sys | |
| import yaml | |
| import numpy as np | |
| from collections import defaultdict | |
| from copy import deepcopy | |
| from dataclasses import dataclass | |
| from functools import partial | |
| from pydub import AudioSegment | |
| from tqdm import tqdm | |
| import torch | |
| import torchvision | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Dataset, get_worker_info | |
| from torch.utils.data.distributed import DistributedSampler | |
| from transformers import AutoTokenizer | |
| import librosa | |
| import soundfile as sf | |
# Maps raw emotion labels (noun and adjective variants) onto one canonical
# label per emotion. Keys and values are runtime data and are kept verbatim,
# including the 'frustated' spelling.
EMOTION_MAP_DICT = dict([
    ('amused', 'amused'),
    ('anger', 'angry'),
    ('angry', 'angry'),
    ('anxious', 'anxious'),
    ('apologetic', 'apologetic'),
    ('assertive', 'assertive'),
    ('calm', 'calm'),
    ('concerned', 'concerned'),
    ('contempt', 'contempt'),
    ('disgust', 'disgusted'),
    ('disgusted', 'disgusted'),
    ('encouraging', 'encouraging'),
    ('excited', 'excited'),
    ('fear', 'fearful'),
    ('fearful', 'fearful'),
    ('frustated', 'frustated'),
    ('happy', 'happy'),
    ('joy', 'happy'),
    ('neutral', 'neutral'),
    ('sad', 'sad'),
    ('sadness', 'sad'),
    ('sleepy', 'sleepy'),
    ('surprise', 'surprised'),
    ('surprised', 'surprised'),
    ('pleasantly surprised', 'pleasantly surprised'),
])
def int16_to_float32(x):
    """Scale an integer PCM array by 1/32767 and return it as float32.

    Args:
        x: NumPy array of sample values (int16 range expected by callers).

    Returns:
        A float32 NumPy array holding ``x / 32767.0``.
    """
    scaled = x / 32767.0
    return scaled.astype(np.float32)
def float32_to_int16(x):
    """Clamp a float array to [-1, 1] and quantise it to int16.

    Args:
        x: NumPy array of float samples.

    Returns:
        An int16 NumPy array holding the clamped samples times 32767
        (the fractional part is dropped by the integer cast).
    """
    bounded = np.clip(x, -1.0, 1.0)
    quantised = bounded * 32767.0
    return quantised.astype(np.int16)
class DataCollator:
    """Collate ``AudioTextData`` samples into one padded batch.

    Each sample is a 5-tuple ``(filename, audio_clips, audio_embed_mask,
    input_ids, attention_mask)``. Audio windows are padded with silence or
    truncated so every sample has the same number of windows, and token
    sequences are right-padded to the longest sequence in the batch.
    """

    def __init__(self, tokenizer, clap_config):
        """Store the tokenizer (for ``pad_token_id``) and the window cap.

        Args:
            tokenizer: object exposing ``pad_token_id``.
            clap_config: mapping containing ``"max_num_window"``.
        """
        self.tokenizer = tokenizer
        self.clap_config = clap_config
        self.max_num_window = clap_config["max_num_window"]

    def __call__(self, batch):
        """Pad/truncate a list of samples and stack them into tensors.

        Args:
            batch: sequence of 5-tuples as described on the class.

        Returns:
            dict with keys ``filenames`` (tuple), ``audio_clips``,
            ``audio_embed_mask``, ``input_ids`` and ``attention_mask``
            (boolean tensor).
        """
        filenames, audio_clips, audio_embed_masks, input_ids, attention_masks = zip(*batch)

        # Number of windows used for this batch: the largest per-sample count,
        # capped at the configured maximum.
        num_windows_all = [sum(audio_embed_mask) for audio_embed_mask in audio_embed_masks]
        max_window_batch = int(max(num_windows_all))
        if max_window_batch > self.max_num_window:
            max_window_batch = self.max_num_window

        padded_audio_clips = []
        padded_audio_embed_masks = []
        for audio_clip, audio_embed_mask in zip(audio_clips, audio_embed_masks):
            clips = list(audio_clip)
            num_windows = len(clips)
            if num_windows < max_window_batch:
                # Too few windows: append silent windows and mask them out.
                silence = torch.zeros_like(clips[-1])
                clips.extend(silence for _ in range(max_window_batch - num_windows))
                audio_embed_mask = torch.zeros(max_window_batch)
                audio_embed_mask[:num_windows] = 1
            elif num_windows > max_window_batch:
                # Too many windows: keep only the leading ones. The original
                # repeated the `<` test here, so this branch never ran, and it
                # concatenated the un-truncated list.
                clips = clips[:max_window_batch]
                audio_embed_mask = audio_embed_mask[:max_window_batch]
            padded_audio_clips.append(torch.cat(clips))
            padded_audio_embed_masks.append(audio_embed_mask)

        audio_clips = torch.stack(padded_audio_clips, dim=0)
        audio_embed_mask = torch.stack(padded_audio_embed_masks, dim=0)

        # Right-pad token ids with the pad token and attention masks with 0.
        max_length = max(ids.shape[1] for ids in input_ids)
        padded_input_ids = []
        padded_attention_masks = []
        for ids, mask in zip(input_ids, attention_masks):
            if ids.shape[1] < max_length:
                id_padding = torch.full(
                    (1, max_length - ids.shape[1]),
                    self.tokenizer.pad_token_id,
                    dtype=torch.long,
                )
                mask_padding = torch.zeros((1, max_length - mask.shape[1]), dtype=torch.long)
                ids = torch.cat([ids, id_padding], dim=1)
                mask = torch.cat([mask, mask_padding], dim=1)
            padded_input_ids.append(ids)
            padded_attention_masks.append(mask)

        padded_input_ids = torch.cat(padded_input_ids, dim=0)
        padded_attention_masks = torch.cat(padded_attention_masks, dim=0).bool()

        return dict(
            filenames=filenames,
            audio_clips=audio_clips,
            audio_embed_mask=audio_embed_mask,
            input_ids=padded_input_ids,
            attention_mask=padded_attention_masks,
        )
class AudioTextData(torch.utils.data.Dataset):
    """Dataset of (audio, prompt, output) records for audio-language training.

    In the ``'train'`` split the records come from a weighted blend of several
    dataset JSON files; in ``'val'``/``'test'`` they come from one named
    dataset file. Each item is returned as
    ``(name, audio_clips, audio_embed_mask, input_ids, attention_mask)``.
    """

    # Audio files that are known to be unreadable and are skipped when blending.
    _BROKEN_FILES = frozenset([
        "/lustre/fsw/portfolios/adlr/users/zkong/datasets/FMA/fma_large/023/023431.mp3",
        "/lustre/fsw/portfolios/adlr/users/zkong/datasets/FMA/fma_large/033/033690.mp3",
        "/lustre/fsw/portfolios/adlr/users/zkong/datasets/FMA/fma_large/119/119217.mp3",
        "/lustre/fsw/portfolios/adlr/users/zkong/datasets/FMA/fma_large/119/119222.mp3",
        "/lustre/fsw/portfolios/adlr/users/zkong/datasets/FMA/fma_large/119/119219.mp3",
        "/lustre/fsw/portfolios/adlr/users/zkong/datasets/GTZAN/gtzan/data/genres/jazz/jazz.00054.wav",
    ])

    def __init__(
        self,
        dataset_file_root: str,
        data_root: str,
        clap_config: dict,
        dataset_blending_global_weight: float,
        dataset_blending_config: dict,
        dataset_blending_output: str,
        tokenizer,
        max_tokens: int,
        split: str = 'train',
        valid_dataset_config: dict = {},
        valid_dataset_name: str = '',
        epoch: int = 0,
        force_reblend: bool = False,
        sr = 16000,
        **kwargs
    ):
        """Load the blended training data or one validation dataset.

        Args:
            dataset_file_root: directory holding the ``<dataset_name>.json`` files.
            data_root: directory joined in front of audio paths when loading.
            clap_config: audio-encoder settings (``method``, ``window_length``,
                ``window_overlap``, ``max_num_window``).
            dataset_blending_global_weight: multiplier applied to every
                per-dataset blending weight.
            dataset_blending_config: ``{dataset_name: {"weight": ...}}``.
            dataset_blending_output: path of the cached blended JSON file.
            tokenizer: tokenizer called on the text sample; its
                ``padding_side`` is set to ``"right"``.
            max_tokens: ``max_length`` passed to the tokenizer.
            split: one of ``'train'``, ``'val'``, ``'test'``.
            valid_dataset_config: accepted for interface compatibility; it is
                never mutated here.
            valid_dataset_name: dataset file stem used for ``'val'``/``'test'``.
            epoch: selects which slice of each dataset is blended.
            force_reblend: rebuild the blend even if the cached file exists.
            sr: stored on the instance as ``self.sr``.
            **kwargs: ignored; lets callers splat a larger config dict.
        """
        self.dataset_file_root = dataset_file_root
        self.data_root = data_root
        self.clap_config = clap_config
        self.dataset_blending_global_weight = dataset_blending_global_weight
        self.dataset_blending_config = dataset_blending_config

        self.sr = sr
        self.split = split
        self.epoch = epoch
        self.force_reblend = force_reblend

        assert self.split in ['train', 'val', 'test']

        if self.split == 'train':
            self.data = self.blend_dataset(dataset_blending_config, dataset_blending_output)
        elif self.split in ['val', 'test']:
            self.valid_data = self.validation_dataset(valid_dataset_config, valid_dataset_name)

        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "right"
        self.max_tokens = max_tokens

    @staticmethod
    def shuffle_dict_fixed_rand(dic, seed=0):
        """Return a dict with the same keys whose values are permuted.

        The permutation is drawn from ``np.random.default_rng(seed)``, so it is
        reproducible for a given seed and key order. Declared static because
        the function takes no ``self`` yet is invoked as ``self.…(dic, seed=…)``.
        """
        print('randomly shuffling key-value pairs')

        local_random = np.random.default_rng(seed)
        original_keys = list(dic.keys())
        shuffled_keys = deepcopy(original_keys)
        local_random.shuffle(shuffled_keys)

        return {
            key: dic[shuffled_key]
            for key, shuffled_key in zip(original_keys, shuffled_keys)
        }

    @staticmethod
    def is_broken_file(audiopath):
        """Return True when ``audiopath`` is on the known-unreadable list."""
        return audiopath in AudioTextData._BROKEN_FILES

    def _read_dataset_file(self, dataset_file):
        """Read one dataset JSON file and normalise its records.

        For files whose path does not contain ``'interleaved'``, every record
        gets ``'task'`` set to the file's ``"flamingo_task"`` and, when the
        file declares a non-null ``"split_path"``, its ``'name'`` prefixed with
        that path. Expected record layout::

            contents['data'] = {idx: {
                'name': rel_path/name,
                'prompt': prompt,
                'output': output,
                [optional] 'audio_start': audio_start,
                'task': task,
            }}
        """
        print("reading", dataset_file)
        with open(dataset_file) as f:
            contents = json.load(f)

        # A null split_path used to leave the prefix unbound (NameError below);
        # in that case the names are now left as they are.
        split_path = contents['split_path']

        if 'interleaved' not in dataset_file:
            for record in contents["data"].values():
                record['task'] = contents["flamingo_task"]
                if split_path is not None:
                    record['name'] = os.path.join(split_path, record['name'])
        return contents

    def blend_dataset(self, dataset_blending_config, dataset_blending_output):
        """Build (or load from cache) the weighted blend of training datasets.

        If ``dataset_blending_output`` exists and ``force_reblend`` is False the
        cached JSON is returned. Otherwise each dataset contributes the index
        range ``[epoch * total * weight, (epoch + 1) * total * weight)``, taken
        modulo its size, with a reshuffle whenever the index wraps around. The
        result is written back to disk (to ``*-reblended.json`` when forced).

        Returns:
            dict with keys ``dataset_path``, ``split_path``, ``total_num`` and
            ``data``.
        """
        if os.path.exists(dataset_blending_output) and not self.force_reblend:
            print("loading blended dataset file from:", dataset_blending_output)
            with open(dataset_blending_output) as f:
                return json.load(f)

        if not self.force_reblend:
            print("no blended dataset file found; reading all dataset files")
        else:
            print("force reblending dataset at epoch {}; reading all dataset files".format(self.epoch))

        all_data = {}
        for dataset_name in dataset_blending_config:
            dataset_file = os.path.join(self.dataset_file_root, '{}.json'.format(dataset_name))
            contents = self._read_dataset_file(dataset_file)
            # The seed is derived from the dataset name so the first shuffle is
            # the same on every run.
            contents['data'] = self.shuffle_dict_fixed_rand(
                contents['data'],
                seed=sum(list(map(ord, dataset_name)))
            )

            weight_global = float(self.dataset_blending_global_weight)
            weight_dataset = float(dataset_blending_config[dataset_name]["weight"])
            all_data[dataset_name] = {
                "contents": contents,
                "weight": weight_global * weight_dataset
            }

        self_data = {
            "dataset_path": self.data_root,
            "split_path": None,
            "total_num": 0,
            "data": {}  # {id: {'name': rel_path/name or [rel_path/names], 'prompt': prompt or [prompts], 'output': output or [outputs], 'task': task, 'interleaved': interleave_method}}
        }

        for dataset_name, entry in all_data.items():
            print('blending {}'.format(dataset_name))

            contents = entry["contents"]
            shuffled_contents_data = contents['data']
            weight = entry["weight"]
            assert isinstance(weight, float) and weight > 0.0

            dataset_total_num = contents['total_num']
            start_idx = int(self.epoch * dataset_total_num * weight)
            end_idx = int((self.epoch + 1) * dataset_total_num * weight)

            for idx in range(start_idx, end_idx):
                if idx > 0 and idx % dataset_total_num == 0:
                    # Wrapped around the dataset: reshuffle with a seed that
                    # depends on the pass number.
                    print('force shuffling at new epoch {} for dataset {}'.format(idx // dataset_total_num, dataset_name))
                    shuffled_contents_data = self.shuffle_dict_fixed_rand(
                        contents['data'],
                        seed=sum(list(map(ord, '{}-epoch-{}'.format(dataset_name, idx // dataset_total_num))))
                    )

                key = str(idx % dataset_total_num)
                item = shuffled_contents_data[key]

                if isinstance(item['name'], str) and self.is_broken_file(item['name']):
                    print('cannot read {}'.format(item['name']))
                    continue

                self_data['data'][self_data['total_num']] = item
                self_data['total_num'] += 1

        if not self.force_reblend:
            print('writing blended dataset file to:', dataset_blending_output)
            output_path = dataset_blending_output
        else:
            output_path = dataset_blending_output.replace('.json', '-reblended.json')
            print('writing reblended dataset file to:', output_path)
        with open(output_path, 'w') as json_file:
            json.dump(self_data, json_file)

        return self_data

    def get_num_windows(self, T, sr):
        """Return ``(num_windows, full_length)`` for ``T`` samples at rate ``sr``.

        ``full_length`` is the number of samples spanned by ``num_windows``
        overlapping windows; the window count is capped at the configured
        ``max_num_window``.
        """
        clap_config = self.clap_config
        window_length = int(float(clap_config["window_length"]) * sr)
        window_overlap = int(float(clap_config["window_overlap"]) * sr)
        max_num_window = int(clap_config["max_num_window"])
        max_length = max_num_window * window_length - (max_num_window - 1) * window_overlap

        if T <= window_length:
            return 1, window_length
        if T >= max_length:
            return max_num_window, max_length

        num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
        full_length = num_windows * window_length - (num_windows - 1) * window_overlap
        return num_windows, full_length

    def load_audio(self, file_path, target_sr=16000, duration=30.0, start=0.0):
        """Load up to ``duration`` seconds of mono audio starting at ``start``.

        ``.mp3`` files are decoded with pydub, everything else with soundfile
        (resampled by librosa when needed). For multi-channel soundfile input
        the first channel is kept.

        Returns:
            1-D NumPy array; signals containing negative samples are scaled to
            peak magnitude 1, non-negative signals are remapped to [-1, 1], and
            an all-zero signal is returned unchanged.
        """
        if file_path.endswith('.mp3'):
            audio = AudioSegment.from_file(file_path)
            # Always cut from `start` (the original only sliced clips longer
            # than start + duration, so the offset was ignored for short ones).
            # pydub slices are in milliseconds and clamp to the clip length.
            audio = audio[int(start * 1000):int((start + duration) * 1000)]
            if audio.frame_rate != target_sr:
                audio = audio.set_frame_rate(target_sr)
            if audio.channels > 1:
                audio = audio.set_channels(1)
            data = np.array(audio.get_array_of_samples())
            if audio.sample_width == 2:
                data = data.astype(np.float32) / np.iinfo(np.int16).max
            elif audio.sample_width == 4:
                data = data.astype(np.float32) / np.iinfo(np.int32).max
            else:
                raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
        else:
            with sf.SoundFile(file_path) as audio:
                original_sr = audio.samplerate
                channels = audio.channels

                audio.seek(int(start * original_sr))
                # Read `duration` seconds from the seek position (the original
                # asked for start + duration seconds); reads stop at EOF.
                data = audio.read(int(duration * original_sr))

            if data.max() > 1 or data.min() < -1:
                data = data / max(abs(data.max()), abs(data.min()))

            if original_sr != target_sr:
                if channels == 1:
                    data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
                else:
                    data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
            elif channels != 1:
                data = data.T[0]

        lowest, highest = data.min(), data.max()
        if lowest >= 0:
            # Non-negative signal: remap [0, max] onto [-1, 1]. Pure silence is
            # left as zeros instead of dividing by zero.
            if highest > 0:
                data = 2 * data / abs(highest) - 1.0
        else:
            data = data / max(abs(highest), abs(lowest))

        assert len(data.shape) == 1, data.shape
        return data

    def compute_sliding_window(self, audio_file, audio_start=0.0, audio="sound"):
        """Load an audio file and cut it into overlapping fixed-length windows.

        Args:
            audio_file: path joined onto ``self.data_root``.
            audio_start: start offset in seconds (float or numeric string).
            audio: only ``"sound"`` is implemented.

        Returns:
            ``(audio_clips, audio_embed_mask)``: a list of ``(1, window_length)``
            tensors and a tensor of ones with one entry per window.

        Raises:
            NotImplementedError: for any other ``audio`` kind or encoder method.
        """
        if isinstance(audio_start, str):
            audio_start = float(audio_start)

        if audio == "sound":
            encoder_config = self.clap_config
        else:
            raise NotImplementedError

        if encoder_config["method"] == 'nvclap-large':
            sr = 16000
        else:
            raise NotImplementedError

        window_seconds = float(encoder_config["window_length"])
        overlap_seconds = float(encoder_config["window_overlap"])
        window_length = int(window_seconds * sr)
        window_overlap = int(overlap_seconds * sr)
        max_num_window = int(encoder_config["max_num_window"])
        # Longest span (in seconds) that the maximum number of windows covers.
        duration = max_num_window * (window_seconds - overlap_seconds) + overlap_seconds

        audio_data = self.load_audio(os.path.join(self.data_root, audio_file), sr, duration, audio_start)  # already cuts to max duration
        T = len(audio_data)
        num_windows, full_length = self.get_num_windows(T, sr)

        # Zero-pad so the last window is complete.
        if full_length > T:
            audio_data = np.append(audio_data, np.zeros(full_length - T))

        audio_data = audio_data.reshape(1, -1)
        # Round-trip through int16 so the samples are quantised to 16 bits.
        audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()

        hop = window_length - window_overlap
        audio_clips = [
            audio_data_tensor[:, i * hop:i * hop + window_length]
            for i in range(num_windows)
        ]
        audio_embed_mask = torch.ones(num_windows)

        return audio_clips, audio_embed_mask

    def validation_dataset(self, valid_dataset_config, valid_dataset_name):
        """Read ``<valid_dataset_name>.json`` and shuffle it with a name-derived seed.

        ``valid_dataset_config`` is accepted but not used here.
        """
        dataset_file = os.path.join(self.dataset_file_root, '{}.json'.format(valid_dataset_name))
        contents = self._read_dataset_file(dataset_file)

        contents['data'] = self.shuffle_dict_fixed_rand(
            contents['data'],
            seed=sum(list(map(ord, valid_dataset_name)))
        )

        return contents

    def preprocess_string_for_eval(self, x):
        """Strip surrounding whitespace and lower-case ``x``."""
        return x.strip().lower()

    def _lookup_item(self, i):
        """Fetch record ``i`` from the active split, by string or int key.

        Blends loaded from JSON have string keys; freshly built blends have
        integer keys, so both are tried.
        """
        store = self.data if self.split == 'train' else self.valid_data
        records = store['data']
        try:
            return records[str(i)]
        except KeyError:
            return records[i]

    def _actual_getitem(self, i):
        """Build the sample tuple for record ``i``.

        Returns:
            ``(name, audio_clips, audio_embed_mask, input_ids, attention_mask)``.

        Raises:
            TypeError: if the record's ``'name'`` is not a single string path.
        """
        item = self._lookup_item(i)

        if not isinstance(item['name'], str):
            raise TypeError(f"The item has a {type(item['name'])}. Only single path as a string is supported")

        # compute_sliding_window joins data_root itself, so the name is passed
        # through unchanged for every split (the val/test path used to join
        # data_root twice).
        audio_file = item['name']
        audio_start = 0 if 'audio_start' not in item else float(item['audio_start'])

        # compute window for long audios
        audio_clips, audio_embed_mask = self.compute_sliding_window(audio_file, audio_start, audio="sound")

        # make the text prompt; train and eval both end up lower-cased and stripped
        text_prompt = self.preprocess_string_for_eval(str(item['prompt']))
        text_output = self.preprocess_string_for_eval(str(item['output']))

        sample = f"<audio>{text_prompt}{self.tokenizer.sep_token}{text_output}<|endofchunk|>{self.tokenizer.eos_token}"

        text = self.tokenizer(
            sample,
            max_length=self.max_tokens,
            padding="longest",
            truncation="only_first",
            return_tensors="pt"
        )

        return (item['name'], audio_clips, audio_embed_mask, text["input_ids"], text["attention_mask"])

    def __getitem__(self, i):
        """Return sample ``i``, falling back to another index if it fails.

        A failing sample is reported and replaced by one at a fixed offset; the
        fallback index is kept inside the dataset even when it holds fewer than
        99 items.
        """
        try:
            return self._actual_getitem(i)
        except Exception as e:
            print('batch {} failed with reason {}'.format(i, e))
            fallback_span = max(1, min(99, len(self)))
            try:
                return self._actual_getitem((i - 42) % fallback_span)
            except Exception:
                return self._actual_getitem((i - 84) % fallback_span)

    def __len__(self):
        """Number of samples; the ``'val'`` split is capped at 64."""
        if self.split == 'train':
            return len(self.data['data'])
        if self.split == 'val':
            return min(len(self.valid_data['data']), 64)
        if self.split == 'test':
            return len(self.valid_data['data'])
        raise ValueError("unknown split: {}".format(self.split))
@dataclass
class DataInfo:
    """Bundle a dataset with its dataloader and optional distributed sampler.

    The ``@dataclass`` decorator supplies the keyword constructor that
    ``get_audiotext_dataloader`` relies on (``DataInfo(dataset=..., ...)``).
    """

    dataset: Dataset
    dataloader: DataLoader
    sampler: DistributedSampler = None  # None when no distributed sampler is used

    def set_epoch(self, epoch):
        """Forward ``epoch`` to the sampler when it is a ``DistributedSampler``."""
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
def get_audiotext_dataloader(data_config, clap_config, text_tokenizer, batch_size, split='train', epoch=0, force_reblend=False):
    """Create the dataloader(s) for one split.

    Args:
        data_config: dataset settings splatted into ``AudioTextData``; also
            supplies ``"num_workers"`` and, for evaluation,
            ``"valid_dataset_config"``.
        clap_config: audio-encoder settings passed to the dataset and collator.
        text_tokenizer: tokenizer shared by the dataset and the collator.
        batch_size: batch size for every dataloader built here.
        split: ``'train'``, ``'val'`` or ``'test'``.
        epoch: forwarded to the training dataset.
        force_reblend: forwarded to the training dataset.

    Returns:
        For ``'train'`` a single ``DataInfo`` whose loader uses a shuffling
        ``DistributedSampler``. For ``'val'``/``'test'`` a dict mapping each
        (stripped) validation dataset name to a ``DataInfo``; ``'val'`` loaders
        use a non-shuffling ``DistributedSampler``, ``'test'`` loaders use none.
    """
    assert split in ['train', 'val', 'test']

    collate = DataCollator(text_tokenizer, clap_config)

    def build_loader(dataset, sampler=None):
        # Shuffling is always off at the DataLoader level; ordering is left to
        # the sampler when one is given.
        return DataLoader(
            dataset,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate,
            num_workers=data_config["num_workers"]
        )

    if split == 'train':
        train_data = AudioTextData(
            **data_config,
            clap_config=clap_config,
            tokenizer=text_tokenizer,
            split=split,
            epoch=epoch,
            force_reblend=force_reblend
        )
        train_sampler = DistributedSampler(train_data, shuffle=True)
        return DataInfo(
            dataset=train_data,
            dataloader=build_loader(train_data, train_sampler),
            sampler=train_sampler
        )

    eval_infos = {}
    for raw_name in list(data_config["valid_dataset_config"].keys()):
        dataset_name = raw_name.strip()
        eval_data = AudioTextData(
            **data_config,
            clap_config=clap_config,
            tokenizer=text_tokenizer,
            split=split,
            valid_dataset_name=dataset_name
        )
        # 'val' is sharded across processes; 'test' runs on a single GPU.
        eval_sampler = DistributedSampler(eval_data, shuffle=False) if split == 'val' else None
        eval_infos[dataset_name] = DataInfo(
            dataset=eval_data,
            dataloader=build_loader(eval_data, eval_sampler)
        )
    return eval_infos
def main():
    """Smoke-test the training dataset and collator from the command line.

    Loads the YAML config given by ``-c/--config``, builds the tokenizer and a
    force-reblended training set, then prints the shapes and contents of the
    first 21 batches.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../configs/config.yaml', help='yaml config path')
    args = parser.parse_args()

    # NOTE(review): FullLoader is kept for compatibility with existing configs;
    # only point this at trusted files, or switch to yaml.safe_load if the
    # configs use plain YAML.
    with open(args.config) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    data_config = config['data_config']
    clap_config = config['clap_config']

    tokenizer_path = "facebook/opt-1.3b"
    cache_dir = '/lustre/fsw/portfolios/adlr/users/sreyang/.cache'
    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=False,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<|PAD_TOKEN|>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    trainset = AudioTextData(
        **data_config,
        clap_config=clap_config, tokenizer=text_tokenizer,
        epoch=66, force_reblend=True
    )

    # DataCollator requires the encoder config as its second argument.
    data_collator = DataCollator(text_tokenizer, clap_config)
    dataloader = DataLoader(trainset, batch_size=16, shuffle=True, collate_fn=data_collator, num_workers=4)

    for step, batch in enumerate(dataloader):
        filenames = batch["filenames"]
        audio_clips = batch["audio_clips"]
        audio_embed_mask = batch["audio_embed_mask"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        print(
            'batch {}:'.format(step+1),
            audio_clips.shape, audio_embed_mask.shape,
            input_ids.shape, attention_mask.shape
        )
        print('filenames', filenames)
        print('audio_embed_mask', audio_embed_mask)
        print('input_ids', input_ids)
        for input_id in input_ids:
            print('-' * 50)
            print(text_tokenizer.decode(input_id))
        print('attention_mask', attention_mask)

        if step == 20:
            break


if __name__ == "__main__":
    main()