# Ke Chen
# [email protected]
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The Main Script
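# Usage: python main.py <mode>, where <mode> is one of:
#   train | test | musdb_process | save_idc | weight_average | inference
# (see main() below for how each mode is dispatched)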
import os
# cap the BLAS thread counts so the SDR calculation does not occupy all CPUs
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "6"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "6"
import sys
import librosa
import numpy as np
import argparse
import logging
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from utils import collect_fn, dump_config, create_folder, prepprocess_audio
import musdb
from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
from data_processor import LGSPDataset, MusdbDataset
import config
import htsat_config
from models.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from htsat_utils import process_idc
import warnings
warnings.filterwarnings("ignore")
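# A LightningDataModule that wraps the train/eval datasets and builds
# (optionally distributed) dataloaders.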
class data_prep(pl.LightningDataModule):
def __init__(self, train_dataset, eval_dataset, device_num, config):
super().__init__()
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.device_num = device_num
self.config = config
def train_dataloader(self):
train_sampler = DistributedSampler(self.train_dataset, shuffle = False) if self.device_num > 1 else None
train_loader = DataLoader(
dataset = self.train_dataset,
num_workers = config.num_workers,
batch_size = config.batch_size // self.device_num,
shuffle = False,
sampler = train_sampler,
collate_fn = collect_fn
)
return train_loader
def val_dataloader(self):
eval_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
eval_loader = DataLoader(
dataset = self.eval_dataset,
num_workers = config.num_workers,
batch_size = config.batch_size // self.device_num,
shuffle = False,
sampler = eval_sampler,
collate_fn = collect_fn
)
return eval_loader
def test_dataloader(self):
test_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
test_loader = DataLoader(
dataset = self.eval_dataset,
num_workers = config.num_workers,
batch_size = config.batch_size // self.device_num,
shuffle = False,
sampler = test_sampler,
collate_fn = collect_fn
)
return test_loader
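# Build the per-class index (idc) lookup files for the training and
# evaluation HDF5 indexes used by the sampler.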
def save_idc():
train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
eval_index_path = os.path.join(config.dataset_path,"hdf5s", "indexes", "eval.h5")
process_idc(train_index_path, config.classes_num, config.index_type + "_idc.npy")
process_idc(eval_index_path, config.classes_num, "eval_idc.npy")
# Resample the musdb tracks to 32000 Hz (the originals are 44100 Hz)
def process_musdb():
# use musdb as testset
test_data = musdb.DB(
root = config.musdb_path,
download = False,
subsets = "test",
is_wav = True
)
print(len(test_data.tracks))
mus_tracks = []
    # in musdb, every track has the same sampling rate (44100 Hz)
orig_fs = test_data.tracks[0].rate
print(orig_fs)
for track in test_data.tracks:
temp = {}
mixture = prepprocess_audio(
track.audio,
orig_fs, config.sample_rate,
config.test_type
)
temp["mixture" ]= mixture
for dickey in config.test_key:
source = prepprocess_audio(
track.targets[dickey].audio,
orig_fs, config.sample_rate,
config.test_type
)
temp[dickey] = source
print(track.audio.shape, len(temp.keys()), temp["mixture"].shape)
mus_tracks.append(temp)
print(len(mus_tracks))
# save the file to npy
np.save("musdb-32000fs.npy", mus_tracks)
# Perform weight averaging over the checkpoints in the given folder.
# It outputs a single checkpoint whose weights are the average of all models in the folder.
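# Required config variables: wa_model_folder (the folder of checkpoints to
# average) and wa_model_path (where the averaged checkpoint is saved).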
def weight_average():
model_ckpt = []
model_files = os.listdir(config.wa_model_folder)
wa_ckpt = {
"state_dict": {}
}
    for model_file in model_files:
        model_file = os.path.join(config.wa_model_folder, model_file)
        model_ckpt.append(torch.load(model_file, map_location="cpu")["state_dict"])
keys = model_ckpt[0].keys()
for key in keys:
model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
model_ckpt_key = torch.mean(model_ckpt_key, dim = 0)
        assert model_ckpt_key.shape == model_ckpt[0][key].shape, f"shape mismatch: {model_ckpt_key.shape} vs {model_ckpt[0][key].shape}"
wa_ckpt["state_dict"][key] = model_ckpt_key
torch.save(wa_ckpt, config.wa_model_path)
# Quickly separate a track given a query with the trained model.
# It requires four variables in config.py:
#   inference_file: the track you want to separate
#   inference_query: a **folder** containing query samples, all from the same source
#   test_key: ["name"] names the source (only used to label the final output)
#   wave_output_path: the output folder
# Make sure the query folder contains samples from a single source.
# Each run separates one source from the track; to separate multiple sources,
# change the query folder or script over several runs.
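# A minimal sketch of these settings in config.py; the paths and the source
# name below are hypothetical placeholders, not values shipped with the repo:
#   inference_file = "examples/mixture.wav"
#   inference_query = "examples/query_vocal"  # folder of .wav query samples
#   test_key = ["vocal"]
#   wave_output_path = "outputs"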
def inference():
# set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "config.test_key must name the source to separate"
create_folder(config.wave_output_path)
test_track, fs = librosa.load(config.inference_file, sr = None)
test_track = test_track[:,None]
print(test_track.shape)
print(fs)
    # resample the track to config.sample_rate (32000 Hz)
test_track = prepprocess_audio(
test_track,
fs, config.sample_rate,
config.test_type
)
test_tracks = []
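    # duplicate the mixture as a placeholder for each target source, matching
    # the track layout MusdbDataset expects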
temp = [test_track]
for dickey in config.test_key:
temp.append(test_track)
temp = np.array(temp)
test_tracks.append(temp)
    dataset = MusdbDataset(tracks = test_tracks) # the format matches the musdb test set, so reuse MusdbDataset
loader = DataLoader(
dataset = dataset,
num_workers = 1,
batch_size = 1,
shuffle = False
)
# obtain the samples for query
queries = []
for query_file in os.listdir(config.inference_query):
f_path = os.path.join(config.inference_query, query_file)
if query_file.endswith(".wav"):
temp_q, fs = librosa.load(f_path, sr = None)
temp_q = temp_q[:, None]
temp_q = prepprocess_audio(
temp_q,
fs, config.sample_rate,
config.test_type
)
temp = [temp_q]
for dickey in config.test_key:
temp.append(temp_q)
temp = np.array(temp)
queries.append(temp)
    assert config.resume_checkpoint is not None, "a saved separation checkpoint is required for inference"
sed_model = HTSAT_Swin_Transformer(
spec_size=htsat_config.htsat_spec_size,
patch_size=htsat_config.htsat_patch_size,
in_chans=1,
num_classes=htsat_config.classes_num,
window_size=htsat_config.htsat_window_size,
config = htsat_config,
depths = htsat_config.htsat_depth,
embed_dim = htsat_config.htsat_dim,
patch_stride=htsat_config.htsat_stride,
num_heads=htsat_config.htsat_num_head
)
at_model = SEDWrapper(
sed_model = sed_model,
config = htsat_config,
dataset = None
)
ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
at_model.load_state_dict(ckpt["state_dict"])
trainer = pl.Trainer(
gpus = 1
)
avg_at = None
# obtain the latent embedding as query
if config.infer_type == "mean":
avg_dataset = MusdbDataset(tracks = queries)
avg_loader = DataLoader(
dataset = avg_dataset,
num_workers = 1,
batch_size = 1,
shuffle = False
)
at_wrapper = AutoTaggingWarpper(
at_model = at_model,
config = config,
target_keys = config.test_key
)
trainer.test(at_wrapper, test_dataloaders = avg_loader)
avg_at = at_wrapper.avg_at
    # load the separation model
model = ZeroShotASP(
channels = 1, config = config,
at_model = at_model,
dataset = dataset
)
# resume checkpoint
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict= False)
exp_model = SeparatorModel(
model = model,
config = config,
target_keys = config.test_key,
avg_at = avg_at,
using_wiener = False,
calc_sdr = False,
output_wav = True
)
trainer.test(exp_model, test_dataloaders = loader)
# test the separation model, mainly in musdb
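# Required config variables: testset_path (the npy file produced by process_musdb),
# testavg_path (the query pool, used when infer_type is "mean"), wave_output_path,
# resume_checkpoint, and using_wiener.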
def test():
# set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "config.test_key must name the sources to separate"
create_folder(config.wave_output_path)
# use musdb as testset
test_data = np.load(config.testset_path, allow_pickle = True)
print(len(test_data))
mus_tracks = []
    # in musdb, every track has the same sampling rate (44100 Hz)
# load the dataset
for track in test_data:
temp = []
mixture = track["mixture"]
temp.append(mixture)
for dickey in config.test_key:
source = track[dickey]
temp.append(source)
temp = np.array(temp)
print(temp.shape)
mus_tracks.append(temp)
print(len(mus_tracks))
dataset = MusdbDataset(tracks = mus_tracks)
loader = DataLoader(
dataset = dataset,
num_workers = 1,
batch_size = 1,
shuffle = False
)
    assert config.resume_checkpoint is not None, "a saved separation checkpoint is required for testing"
sed_model = HTSAT_Swin_Transformer(
spec_size=htsat_config.htsat_spec_size,
patch_size=htsat_config.htsat_patch_size,
in_chans=1,
num_classes=htsat_config.classes_num,
window_size=htsat_config.htsat_window_size,
config = htsat_config,
depths = htsat_config.htsat_depth,
embed_dim = htsat_config.htsat_dim,
patch_stride=htsat_config.htsat_stride,
num_heads=htsat_config.htsat_num_head
)
at_model = SEDWrapper(
sed_model = sed_model,
config = htsat_config,
dataset = None
)
ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
at_model.load_state_dict(ckpt["state_dict"])
trainer = pl.Trainer(
gpus = 1
)
avg_at = None
# obtain the query of four stems from the training set
if config.infer_type == "mean":
        avg_data = np.load(config.testavg_path, allow_pickle = True)[:90] # use the first 90 tracks as the query pool
print(len(avg_data))
avgmus_tracks = []
        # in musdb, every track has the same sampling rate (44100 Hz)
# load the dataset
for track in avg_data:
temp = []
mixture = track["mixture"]
temp.append(mixture)
for dickey in config.test_key:
source = track[dickey]
temp.append(source)
temp = np.array(temp)
print(temp.shape)
avgmus_tracks.append(temp)
print(len(avgmus_tracks))
avg_dataset = MusdbDataset(tracks = avgmus_tracks)
avg_loader = DataLoader(
dataset = avg_dataset,
num_workers = 1,
batch_size = 1,
shuffle = False
)
at_wrapper = AutoTaggingWarpper(
at_model = at_model,
config = config,
target_keys = config.test_key
)
trainer.test(at_wrapper, test_dataloaders = avg_loader)
avg_at = at_wrapper.avg_at
model = ZeroShotASP(
channels = 1, config = config,
at_model = at_model,
dataset = dataset
)
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict= False)
exp_model = SeparatorModel(
model = model,
config = config,
target_keys = config.test_key,
avg_at = avg_at,
using_wiener = config.using_wiener
)
trainer.test(exp_model, test_dataloaders = loader)
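# Train ZeroShotASP end-to-end; the pretrained HTS-AT checkpoint
# (htsat_config.resume_checkpoint) supplies the audio-tagging query encoder.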
def train():
# set exp settings
# device_name = "cuda" if torch.cuda.is_available() else "cpu"
# device = torch.device("cuda")
device_num = torch.cuda.device_count()
print("each batch size:", config.batch_size // device_num)
train_index_path = os.path.join(config.dataset_path, "hdf5s","indexes", config.index_type + ".h5")
train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle = True)
eval_index_path = os.path.join(config.dataset_path,"hdf5s", "indexes", "eval.h5")
eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle = True)
# set exp folder
exp_dir = os.path.join(config.workspace, "results", config.exp_name)
checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
if not config.debug:
create_folder(os.path.join(config.workspace, "results"))
create_folder(exp_dir)
create_folder(checkpoint_dir)
dump_config(config, os.path.join(exp_dir, config.exp_name), False)
    # load data
    # LGSPDataset: the latent general source separation dataset
dataset = LGSPDataset(
index_path = train_index_path,
idc = train_idc,
config = config,
factor = 0.05,
eval_mode = False
)
eval_dataset = LGSPDataset(
index_path = eval_index_path,
idc = eval_idc,
config = config,
factor = 0.05,
eval_mode = True
)
    audioset_data = data_prep(train_dataset = dataset, eval_dataset = eval_dataset, device_num = device_num, config = config)
checkpoint_callback = ModelCheckpoint(
monitor = "mixture_sdr",
filename='l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
save_top_k = 10,
mode = "max"
)
# infer at model
sed_model = HTSAT_Swin_Transformer(
spec_size=htsat_config.htsat_spec_size,
patch_size=htsat_config.htsat_patch_size,
in_chans=1,
num_classes=htsat_config.classes_num,
window_size=htsat_config.htsat_window_size,
config = htsat_config,
depths = htsat_config.htsat_depth,
embed_dim = htsat_config.htsat_dim,
patch_stride=htsat_config.htsat_stride,
num_heads=htsat_config.htsat_num_head
)
at_model = SEDWrapper(
sed_model = sed_model,
config = htsat_config,
dataset = None
)
# load the checkpoint
ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
at_model.load_state_dict(ckpt["state_dict"])
trainer = pl.Trainer(
deterministic=True,
default_root_dir = checkpoint_dir,
gpus = device_num,
val_check_interval = 0.2,
# check_val_every_n_epoch = 1,
max_epochs = config.max_epoch,
auto_lr_find = True,
sync_batchnorm = True,
callbacks = [checkpoint_callback],
accelerator = "ddp" if device_num > 1 else None,
resume_from_checkpoint = None, #config.resume_checkpoint,
replace_sampler_ddp = False,
gradient_clip_val=1.0,
num_sanity_val_steps = 0,
)
model = ZeroShotASP(
channels = 1, config = config,
at_model = at_model,
dataset = dataset
)
if config.resume_checkpoint is not None:
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
# trainer.test(model, datamodule = audioset_data)
trainer.fit(model, audioset_data)
def main():
    parser = argparse.ArgumentParser(description="latent general source separation parser")
subparsers = parser.add_subparsers(dest = "mode")
parser_train = subparsers.add_parser("train")
parser_test = subparsers.add_parser("test")
parser_musdb = subparsers.add_parser("musdb_process")
parser_saveidc = subparsers.add_parser("save_idc")
parser_wa = subparsers.add_parser("weight_average")
parser_infer = subparsers.add_parser("inference")
args = parser.parse_args()
# default settings
logging.basicConfig(level=logging.INFO)
pl.utilities.seed.seed_everything(seed = config.random_seed)
if args.mode == "train":
train()
elif args.mode == "test":
test()
elif args.mode == "musdb_process":
process_musdb()
elif args.mode == "weight_average":
weight_average()
elif args.mode == "save_idc":
save_idc()
elif args.mode == "inference":
inference()
else:
raise Exception("Error Mode!")
if __name__ == '__main__':
main()