import sys

sys.path.append("src")

import io
import json
import math
import os
import random
import threading
import zlib

import lmdb
import librosa
import numpy as np
import pandas as pd
import torch
import torch.nn.functional
import torchaudio
import yaml
from librosa.filters import mel as librosa_mel_fn
from pydub import AudioSegment
from torch.utils.data import Dataset

import qa_mdt.audioldm_train.utilities.audio as Audio
from qa_mdt.audioldm_train.utilities.tools import load_json
from qa_mdt.audioldm_train.dataset_plugin import *

from .datum_all_pb2 import Datum_all as Datum_lmdb
from .datum_mos_pb2 import Datum_mos as Datum_lmdb_mos
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    # Log-compress magnitudes; clip_val avoids log(0).
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    # Inverse of dynamic_range_compression_torch (exact for x >= clip_val).
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)
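

# Illustrative sketch (not part of the training pipeline): for inputs at or
# above clip_val the two helpers above are exact inverses, which is why the
# normalized mel features can be mapped back to linear magnitudes later.
def _compression_roundtrip_example():
    x = torch.tensor([0.1, 0.5, 1.0])
    y = spectral_de_normalize_torch(spectral_normalize_torch(x))
    assert torch.allclose(x, y)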
|
|
|
|
|
class AudioDataset(Dataset):
    def __init__(
        self,
        config,
        lmdb_path,
        key_path,
        mos_path,
        lock=True,
    ):
        """
        Dataset that manages audio recordings stored in LMDB, with an
        auxiliary LMDB of MOS (quality) scores.
        """
        self.config = config
        self.pad_wav_start_sample = 0
        self.trim_wav_enabled = False
        self.build_setting_parameters()
        self.build_dsp()

        self.lmdb_path = lmdb_path
        self.lmdb_env = [
            lmdb.open(path, readonly=True, lock=False) for path in lmdb_path
        ]
        self.mos_txn_env = lmdb.open(mos_path, readonly=True, lock=False)

        # Each key file holds one record per line; the first whitespace-separated
        # token is the LMDB key. Entries are stored as (lmdb_index, key_bytes).
        self.key_path = key_path
        self.keys = []
        for lmdb_index, path in enumerate(self.key_path):
            with open(path) as f:
                for line in f:
                    key = line.strip()
                    self.keys.append((lmdb_index, key.split()[0].encode("utf-8")))

        print(f"Dataset initialization finished, dataset length: {len(self.keys)}")
        print("Initializing filter list ...")
        with open("filter_all.lst", "r") as f:
            self.filter = {}
            for line in f.readlines():
                self.filter[line.strip()] = 1
        print("Filter list initialized")
|
|
    def __getitem__(self, index):
        (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_vector,
            random_start,
            caption,
            mos,
        ) = self.feature_extraction(index)

        data = {
            "text": [caption],
            "fname": [fname],
            # An empty string marks features that were not computed.
            "waveform": "" if (waveform is None) else waveform.float(),
            "stft": "" if (stft is None) else stft.float(),
            "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
            "duration": self.duration,
            "sampling_rate": self.sampling_rate,
            "random_start_sample_in_original_audio_file": random_start,
            "label_vector": label_vector,
            "mos": mos,
        }

        if data["text"] is None:
            print("Warning: The model returned None on key text", fname)
            data["text"] = ""

        return data

    def __len__(self):
        return len(self.keys)
|
|
|
    def feature_extraction(self, index):
        if index > len(self.keys) - 1:
            print(
                "The index of the dataloader is out of range: %s/%s"
                % (index, len(self.keys))
            )
            index = random.randint(0, len(self.keys) - 1)

        # Re-draw a random index while the sampled key is on the blocklist.
        flag = 0
        while flag == 0:
            id_, k = self.keys[index]
            if k.decode() in self.filter:
                index = random.randint(0, len(self.keys) - 1)
            else:
                flag = 1

        # Read the compressed waveform from LMDB; on failure, fall back to a
        # random key until a usable waveform is found.
        waveform = np.array([])
        read_attempts = 0
        last_index = index
        while len(waveform) < 1000:
            id_, k = self.keys[index]
            with self.lmdb_env[id_].begin(write=False) as txn:
                cursor = txn.cursor()
                try:
                    cursor.set_key(k)
                    datum_tmp = Datum_lmdb()
                    datum_tmp.ParseFromString(cursor.value())
                    decompressed_bytes = zlib.decompressobj().decompress(
                        datum_tmp.wav_file
                    )
                    waveform = np.frombuffer(decompressed_bytes, dtype=np.float32)
                except Exception:
                    read_attempts += 1
            read_attempts += 1
            last_index = index
            index = random.randint(0, len(self.keys) - 1)
        if read_attempts > 1:
            print("error")
        # The loop always advances index; step back to the key that succeeded.
        index = last_index

        # Quality-aware rejection sampling: accept a clip with probability
        # exp(5 * mos) / val. Since val ~= exp(5 * 4.05), clips with
        # mos >= 4.05 are always kept, and acceptance decays as
        # exp(5 * (mos - 4.05)) below that threshold.
        flag = 0
        val = 623787092.84794
        while flag == 0:
            id_, k = self.keys[index]
            with self.mos_txn_env.begin(write=False) as txn:
                cursor = txn.cursor()
                try:
                    if cursor.set_key(k):
                        datum_mos = Datum_lmdb_mos()
                        datum_mos.ParseFromString(cursor.value())
                        mos = datum_mos.mos
                    else:
                        mos = -1.0
                except Exception:
                    mos = -1.0
            # pixa_/ifly_ sources are trusted as high quality.
            if "pixa_" in k.decode() or "ifly_" in k.decode():
                mos = 5.0
            if np.random.rand() < math.exp(5.0 * mos) / val:
                flag = 1
            last_index = index
            index = random.randint(0, len(self.keys) - 1)
        index = last_index

        caption_original = datum_tmp.caption_original
        try:
            caption_generated = datum_tmp.caption_generated[0]
        except Exception:
            caption_generated = "None"
        assert len(caption_generated) > 1
        caption_original = caption_original.lower()
        caption_generated = caption_generated.lower()

        # Per-source caption choice: prefer whichever of the original or the
        # generated caption is reliable for that corpus.
        caption = "music"
        if "msd_" in k.decode():
            caption = caption_generated if caption_original == "none" else caption_original
        elif "audioset_" in k.decode():
            caption = caption_generated if caption_generated != "none" else caption_original
        elif "mtt_" in k.decode():
            caption = caption_generated if caption_original == "none" else caption_original
        elif "fma_" in k.decode():
            caption = caption_generated if caption_generated != "none" else caption_original
        elif "pixa_" in k.decode() or "ifly_" in k.decode():
            caption = caption_generated if caption_generated != "none" else caption_original
        else:
            caption = caption_original

        # Map the MOS score to a textual quality prefix.
        prefix = "medium quality"
        if "pixa_" in k.decode() or "ifly_" in k.decode():
            prefix = "high quality"
            if caption == "none":
                caption = ""
            mos = 5.00
        else:
            mos = float(mos)
            if 3.55 < mos < 4.05:
                prefix = "medium quality"
            elif mos >= 4.05:
                prefix = "high quality"
            elif mos <= 3.55:
                prefix = "low quality"
            else:
                print(f"mos score for key {k.decode()} is missing, please check")

        caption = prefix + ", " + caption

        # Discretize MOS into a 1-5 quality token centered at miu = 3.80 with
        # step sigma = 0.20.
        miu = 3.80
        sigma = 0.20
        if miu - 2 * sigma <= mos < miu - sigma:
            vq_mos = 2
        elif miu - sigma <= mos < miu + sigma:
            vq_mos = 3
        elif miu + sigma <= mos < miu + 2 * sigma:
            vq_mos = 4
        elif mos >= miu + 2 * sigma:
            vq_mos = 5
        else:
            vq_mos = 1

        """
        tags = datum_tmp.tags.decode()
        caption_writing = datum_tmp.caption_writing.decode()
        caption_paraphrase = datum_tmp.caption_paraphrase.decode()
        caption_attribute_prediction = datum_tmp.caption_attribute_prediction.decode()
        caption_summary = datum_tmp.caption_summary.decode()
        """

        (
            log_mel_spec,
            stft,
            waveform,
            random_start,
        ) = self.read_audio_file(waveform, k.decode())
        fname = self.keys[index]

        waveform = torch.FloatTensor(waveform)
        label_vector = torch.FloatTensor(np.zeros(0, dtype=np.float32))

        return (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_vector,
            random_start,
            caption,
            vq_mos,
        )
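
    # Illustrative helper (added for clarity; not called by the pipeline):
    # the acceptance probability used in feature_extraction, with the same
    # normalizer val ~= exp(5 * 4.05).
    @staticmethod
    def _mos_acceptance_probability(mos, val=623787092.84794):
        return min(1.0, math.exp(5.0 * mos) / val)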
|
|
|
    def build_setting_parameters(self):
        self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
        self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
        self.duration = self.config["preprocessing"]["audio"]["duration"]
        # Number of spectrogram frames for a full-duration clip.
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)

        self.mixup = self.config["augmentation"]["mixup"]
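
        # Worked example (hypothetical config values): with duration = 10.24 s,
        # sampling_rate = 16000 and hop_length = 160, target_length is
        # int(10.24 * 16000 / 160) = 1024 frames.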
|
|
    def build_dsp(self):
        # Per-device caches for the mel filterbank and STFT window.
        self.mel_basis = {}
        self.hann_window = {}

        self.filter_length = self.config["preprocessing"]["stft"]["filter_length"]
        self.hop_length = self.config["preprocessing"]["stft"]["hop_length"]
        self.win_length = self.config["preprocessing"]["stft"]["win_length"]
        self.n_mel = self.config["preprocessing"]["mel"]["n_mel_channels"]
        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
        self.mel_fmin = self.config["preprocessing"]["mel"]["mel_fmin"]
        self.mel_fmax = self.config["preprocessing"]["mel"]["mel_fmax"]

        self.STFT = Audio.stft.TacotronSTFT(
            self.filter_length,
            self.hop_length,
            self.win_length,
            self.n_mel,
            self.sampling_rate,
            self.mel_fmin,
            self.mel_fmax,
        )
|
|
|
    def resample(self, waveform, sr):
        return torchaudio.functional.resample(waveform, sr, self.sampling_rate)
|
|
    def normalize_wav(self, waveform):
        # Remove DC offset, then scale the peak to 0.5 to leave headroom.
        waveform = waveform - np.mean(waveform)
        waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
        return waveform * 0.5
|
|
|
    def random_segment_wav(self, waveform, target_length):
        waveform = torch.tensor(waveform)
        waveform = waveform.unsqueeze(0)
        waveform_length = waveform.shape[-1]

        # Very short clips are zero-padded up to the target length.
        if waveform_length < 100:
            waveform = torch.nn.functional.pad(
                waveform, (0, target_length - waveform_length)
            )

        if (waveform_length - target_length) <= 0:
            return waveform, 0

        # Try up to 10 random crops, preferring one that is not silent.
        for i in range(10):
            random_start = int(self.random_uniform(0, waveform_length - target_length))
            if torch.max(
                torch.abs(waveform[:, random_start : random_start + target_length])
                > 1e-4
            ):
                break

        return waveform[:, random_start : random_start + target_length], random_start
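
        # Note: torch.max over the boolean mask above returns True as soon as
        # any sample in the crop exceeds 1e-4, so the first non-silent crop wins.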
|
|
|
    def pad_wav(self, waveform, target_length):
        waveform_length = waveform.shape[-1]

        # Very short clips are zero-padded directly to the target length.
        if waveform_length < 100:
            return np.pad(
                waveform,
                ((0, 0), (0, target_length - waveform_length)),
                mode="constant",
            )

        if waveform_length == target_length:
            return waveform

        temp_wav = np.zeros((1, target_length), dtype=np.float32)
        if self.pad_wav_start_sample is None:
            rand_start = int(self.random_uniform(0, target_length - waveform_length))
        else:
            rand_start = 0

        temp_wav[:, rand_start : rand_start + waveform_length] = waveform
        return temp_wav
|
|
|
    def trim_wav(self, waveform):
        # Strip leading and trailing silence (threshold 1e-4, 1000-sample chunks).
        if np.max(np.abs(waveform)) < 0.0001:
            return waveform

        def detect_leading_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = 0
            while start + chunk_size < waveform_length:
                if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
                    start += chunk_size
                else:
                    break
            return start

        def detect_ending_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = waveform_length
            while start - chunk_size > 0:
                if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
                    start -= chunk_size
                else:
                    break
            if start == waveform_length:
                return start
            else:
                return start + chunk_size

        start = detect_leading_silence(waveform)
        end = detect_ending_silence(waveform)

        return waveform[start:end]
|
|
|
    def read_wav_file(self, file, k):
        # `file` is already a decoded float32 waveform from LMDB.
        waveform = file

        # Stored waveforms are treated as 16 kHz here (hard-coded); crop first,
        # then resample to the configured sampling rate.
        sr = 16000
        waveform, random_start = self.random_segment_wav(
            waveform, target_length=int(sr * self.duration)
        )
        waveform = self.resample(waveform, sr)
        waveform = waveform.numpy()[0, ...]

        waveform = self.normalize_wav(waveform)

        if self.trim_wav_enabled:
            waveform = self.trim_wav(waveform)

        waveform = waveform[None, ...]
        waveform = self.pad_wav(
            waveform, target_length=int(self.sampling_rate * self.duration)
        )
        return waveform, random_start
|
|
|
    def mix_two_waveforms(self, waveform1, waveform2):
        # Mixup with a symmetric Beta(5, 5) mixing coefficient (mean 0.5).
        mix_lambda = np.random.beta(5, 5)
        mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
        return self.normalize_wav(mix_waveform), mix_lambda
|
|
|
    def read_audio_file(self, file, k):
        waveform, random_start = self.read_wav_file(file, k)
        log_mel_spec, stft = self.wav_feature_extraction(waveform)
        return log_mel_spec, stft, waveform, random_start
|
|
|
    def mel_spectrogram_train(self, y):
        if torch.min(y) < -1.0:
            print("train min value is ", torch.min(y))
        if torch.max(y) > 1.0:
            print("train max value is ", torch.max(y))

        # Build (and cache) the mel filterbank and window per fmax/device.
        mel_key = str(self.mel_fmax) + "_" + str(y.device)
        if mel_key not in self.mel_basis:
            mel = librosa_mel_fn(
                sr=self.sampling_rate,
                n_fft=self.filter_length,
                n_mels=self.n_mel,
                fmin=self.mel_fmin,
                fmax=self.mel_fmax,
            )
            self.mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
            self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
                y.device
            )

        # Reflect-pad so that frame positions line up with center=False below.
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.filter_length - self.hop_length) / 2),
                int((self.filter_length - self.hop_length) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)

        stft_spec = torch.stft(
            y,
            self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.hann_window[str(y.device)],
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        stft_spec = torch.abs(stft_spec)

        mel = spectral_normalize_torch(
            torch.matmul(self.mel_basis[mel_key], stft_spec)
        )

        return mel[0], stft_spec[0]
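
        # Shapes for a single waveform: mel is (n_mel, T) and stft_spec is
        # (filter_length // 2 + 1, T); wav_feature_extraction transposes both.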
|
|
|
|
|
    def wav_feature_extraction(self, waveform):
        waveform = waveform[0, ...]
        waveform = torch.FloatTensor(waveform)

        log_mel_spec, stft = self.mel_spectrogram_train(waveform.unsqueeze(0))

        log_mel_spec = torch.FloatTensor(log_mel_spec.T)
        stft = torch.FloatTensor(stft.T)

        log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
        return log_mel_spec, stft
|
|
|
    def pad_spec(self, log_mel_spec):
        # Pad or truncate along time to exactly target_length frames.
        n_frames = log_mel_spec.shape[0]
        p = self.target_length - n_frames

        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            log_mel_spec = m(log_mel_spec)
        elif p < 0:
            log_mel_spec = log_mel_spec[0 : self.target_length, :]

        # Keep an even number of bins in the last dimension.
        if log_mel_spec.size(-1) % 2 != 0:
            log_mel_spec = log_mel_spec[..., :-1]

        return log_mel_spec
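
        # Example: with target_length = 1024, a 900-frame mel is zero-padded to
        # 1024 frames and an 1100-frame mel is truncated; an odd last dimension
        # (e.g. 513 STFT bins) is trimmed to 512 for even divisibility.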
|
|
|
    def _read_datum_caption(self, datum):
        # Pick one of the datum's caption fields uniformly at random.
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
        return datum[caption_keys[random_index]]

    def _is_contain_caption(self, datum):
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        return len(caption_keys) > 0

    def label_indices_to_text(self, datum, label_indices):
        if self._is_contain_caption(datum):
            return self._read_datum_caption(datum)
        elif "label" in datum.keys():
            name_indices = torch.where(label_indices > 0.1)[0]

            description_header = ""
            labels = ""
            for id, each in enumerate(name_indices):
                if id == len(name_indices) - 1:
                    labels += "%s." % self.num2label[int(each)]
                else:
                    labels += "%s, " % self.num2label[int(each)]
            return description_header + labels
        else:
            return ""
|
|
|
    def random_uniform(self, start, end):
        val = torch.rand(1).item()
        return start + (end - start) * val

    def frequency_masking(self, log_mel_spec, freqm):
        # SpecAugment-style frequency mask with length in [freqm // 8, freqm).
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(freqm // 8, freqm))
        mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
        log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
        return log_mel_spec

    def time_masking(self, log_mel_spec, timem):
        # SpecAugment-style time mask with length in [timem // 8, timem).
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(timem // 8, timem))
        mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
        log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
        return log_mel_spec
|
|
|
|
|
class AudioDataset_infer(Dataset):
    def __init__(
        self,
        config,
        caption_list,
        lock=True,
    ):
        """
        Dataset that manages captions for inference.
        """
        self.config = config
        self.captions = []
        with open(caption_list, "r") as f:
            for line in f:
                key = line.strip()
                self.captions.append(key.split()[0])
        self.duration = self.config["preprocessing"]["audio"]["duration"]
        self.sampling_rate = self.config["variables"]["sampling_rate"]
        self.target_length = int(self.sampling_rate * self.duration)
        self.waveform = torch.zeros((1, self.target_length))

    def __getitem__(self, index):
        fname = [f"sample_{index}"]
        data = {
            "text": [self.captions[index]],
            "fname": fname,
            "waveform": "",
            "stft": "",
            "log_mel_spec": "",
            "duration": self.duration,
            "sampling_rate": self.sampling_rate,
            "random_start_sample_in_original_audio_file": 0,
            "label_vector": torch.FloatTensor(np.zeros(0, dtype=np.float32)),
            # Assumed default MOS conditioning for generation (high quality).
            "mos": 5.0,
        }

        if data["text"] is None:
            print("Warning: The model returned None on key text", fname)
            data["text"] = ""

        return data

    def __len__(self):
        return len(self.captions)
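

# Minimal usage sketch for AudioDataset_infer (illustrative; "captions.lst" is
# a placeholder, and only the first whitespace-separated token of each line is
# used as the caption):
#
#   infer_set = AudioDataset_infer(config, caption_list="captions.lst")
#   infer_loader = DataLoader(infer_set, batch_size=4, shuffle=False)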
|
|
|
if __name__ == "__main__":
    from tqdm import tqdm
    from pytorch_lightning import seed_everything
    from torch.utils.data import DataLoader

    seed_everything(0)

    def write_json(my_dict, fname):
        json_str = json.dumps(my_dict)
        with open(fname, "w") as json_file:
            json_file.write(json_str)

    def load_json(fname):
        with open(fname, "r") as f:
            data = json.load(f)
        return data

    config = yaml.load(
        open(
            "/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/config/vae_48k_256/ds_8_kl_1.0_ch_16.yaml",
            "r",
        ),
        Loader=yaml.FullLoader,
    )

    add_ons = config["data"]["dataloader_add_ons"]

    # The LMDB, key-list and MOS paths below are illustrative placeholders;
    # point them at real data before running.
    dataset = AudioDataset(
        config=config,
        lmdb_path=["/path/to/audio_lmdb"],
        key_path=["/path/to/keys.lst"],
        mos_path="/path/to/mos_lmdb",
    )

    loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)
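
    # Illustrative smoke test (assumes the placeholder paths above point to
    # real data): fetch one batch and report the feature shapes.
    for batch in tqdm(loader):
        print(batch["fname"], batch["text"])
        if isinstance(batch["log_mel_spec"], torch.Tensor):
            print("log_mel_spec:", batch["log_mel_spec"].shape)
        break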
|