|
|
|
|
|
|
|
|
|
|
|
import os

# Cap the thread pools of the native math backends (OpenMP, OpenBLAS, MKL,
# Accelerate/vecLib, numexpr). These variables are only honored if set
# BEFORE numpy/torch are imported, which is why this block precedes all
# other imports.
# NOTE(review): MKL and NUMEXPR are capped at 6 while the others use 4 —
# confirm the mismatch is intentional.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "6"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "6"
|
|
|
import sys |
|
import librosa |
|
import numpy as np |
|
import argparse |
|
import logging |
|
|
|
import torch |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
from utils import collect_fn, dump_config, create_folder, prepprocess_audio |
|
import musdb |
|
|
|
from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper |
|
from data_processor import LGSPDataset, MusdbDataset |
|
import config |
|
import htsat_config |
|
from models.htsat import HTSAT_Swin_Transformer |
|
from sed_model import SEDWrapper |
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
|
from htsat_utils import process_idc |
|
|
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
class data_prep(pl.LightningDataModule):
    """Lightning data module wrapping the train/eval datasets.

    The eval dataset backs both the validation and the test loaders.
    Sharding across devices is done manually with DistributedSampler
    (the trainer is configured with replace_sampler_ddp=False), and
    shuffling is disabled so sample order is reproducible across ranks.
    """

    def __init__(self, train_dataset, eval_dataset, device_num, config):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.device_num = device_num
        self.config = config

    def _make_loader(self, dataset):
        """Build a (optionally distributed) DataLoader over *dataset*.

        Fix: reads num_workers/batch_size from the config object handed
        to __init__ (self.config) instead of the global `config` module,
        so the module honors whatever configuration it was constructed
        with.
        """
        sampler = DistributedSampler(dataset, shuffle=False) if self.device_num > 1 else None
        return DataLoader(
            dataset=dataset,
            num_workers=self.config.num_workers,
            # Global batch size is split evenly across devices.
            batch_size=self.config.batch_size // self.device_num,
            shuffle=False,
            sampler=sampler,
            collate_fn=collect_fn,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.eval_dataset)

    def test_dataloader(self):
        return self._make_loader(self.eval_dataset)
|
|
|
def save_idc():
    """Build and save the per-class sample-index (idc) lookup files.

    Processes the training index (named after config.index_type) and the
    fixed "eval" index, writing one .npy file for each.
    """
    indexes_dir = os.path.join(config.dataset_path, "hdf5s", "indexes")
    jobs = (
        (config.index_type + ".h5", config.index_type + "_idc.npy"),
        ("eval.h5", "eval_idc.npy"),
    )
    for h5_name, out_name in jobs:
        process_idc(os.path.join(indexes_dir, h5_name), config.classes_num, out_name)
|
|
|
|
|
def process_musdb():
    """Preprocess the MUSDB18 test split and dump it as a .npy track list.

    Each output entry is a dict with the resampled "mixture" plus one key
    per target source in config.test_key.
    """
    test_data = musdb.DB(
        root=config.musdb_path,
        download=False,
        subsets="test",
        is_wav=True,
    )
    print(len(test_data.tracks))

    # All MUSDB tracks share one sample rate; read it off the first track.
    orig_fs = test_data.tracks[0].rate
    print(orig_fs)

    mus_tracks = []
    for track in test_data.tracks:
        entry = {
            "mixture": prepprocess_audio(
                track.audio,
                orig_fs, config.sample_rate,
                config.test_type,
            )
        }
        for key in config.test_key:
            entry[key] = prepprocess_audio(
                track.targets[key].audio,
                orig_fs, config.sample_rate,
                config.test_type,
            )
        print(track.audio.shape, len(entry.keys()), entry["mixture"].shape)
        mus_tracks.append(entry)
    print(len(mus_tracks))

    np.save("musdb-32000fs.npy", mus_tracks)
|
|
|
|
|
|
|
def weight_average():
    """Average the weights of every checkpoint in config.wa_model_folder.

    Loads each checkpoint's state_dict, takes the element-wise mean of
    every parameter across checkpoints (SWA-style), and saves the averaged
    state dict to config.wa_model_path.
    """
    model_ckpt = []
    model_files = os.listdir(config.wa_model_folder)
    wa_ckpt = {
        "state_dict": {}
    }

    for model_file in model_files:
        # Fix: join against the same folder that was listed above (was
        # config.esm_model_folder, which breaks whenever the two paths
        # differ).
        model_file = os.path.join(config.wa_model_folder, model_file)
        model_ckpt.append(torch.load(model_file, map_location="cpu")["state_dict"])
    keys = model_ckpt[0].keys()
    for key in keys:
        # Stack this parameter from every checkpoint, then mean over them.
        model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
        model_ckpt_key = torch.mean(model_ckpt_key, dim=0)
        # Fix: the original message built "str" + torch.Size, which raised
        # TypeError instead of the intended AssertionError text.
        assert model_ckpt_key.shape == model_ckpt[0][key].shape, (
            f"the shape is unmatched {model_ckpt_key.shape} {model_ckpt[0][key].shape}"
        )
        wa_ckpt["state_dict"][key] = model_ckpt_key
    torch.save(wa_ckpt, config.wa_model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference():
    """Separate config.inference_file into the sources in config.test_key.

    Pipeline: load and resample the mixture, build latent query embeddings
    from the .wav files in config.inference_query with a pretrained HTS-AT
    tagging model (averaged over all queries when config.infer_type ==
    "mean"), then run the ZeroShotASP separator and write the separated
    waveforms to config.wave_output_path.
    """
    # Fix: run on GPU only when one is available. The original computed
    # device_name but never used it, hard-coded torch.device("cuda") into
    # a dead variable, and passed gpus=1 unconditionally — crashing on
    # CPU-only machines.
    gpu_num = 1 if torch.cuda.is_available() else 0
    assert config.test_key is not None, "there should be a separate key"
    create_folder(config.wave_output_path)

    test_track, fs = librosa.load(config.inference_file, sr=None)
    test_track = test_track[:, None]  # mono (samples,) -> (samples, 1)
    print(test_track.shape)
    print(fs)

    test_track = prepprocess_audio(
        test_track,
        fs, config.sample_rate,
        config.test_type
    )
    # MusdbDataset expects [mixture, source_1, ..., source_k]; the true
    # sources are unknown at inference time, so duplicate the mixture as
    # a placeholder for each target key.
    test_tracks = []
    temp = [test_track]
    for dickey in config.test_key:
        temp.append(test_track)
    temp = np.array(temp)
    test_tracks.append(temp)
    dataset = MusdbDataset(tracks=test_tracks)
    loader = DataLoader(
        dataset=dataset,
        num_workers=1,
        batch_size=1,
        shuffle=False
    )

    # Collect query audio: one stacked array per .wav file in the folder.
    queries = []
    for query_file in os.listdir(config.inference_query):
        f_path = os.path.join(config.inference_query, query_file)
        if query_file.endswith(".wav"):
            temp_q, fs = librosa.load(f_path, sr=None)
            temp_q = temp_q[:, None]
            temp_q = prepprocess_audio(
                temp_q,
                fs, config.sample_rate,
                config.test_type
            )
            temp = [temp_q]
            for dickey in config.test_key:
                temp.append(temp_q)
            temp = np.array(temp)
            queries.append(temp)

    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"

    # Pretrained HTS-AT audio-tagging backbone; its embedding conditions
    # the separator on what source to extract.
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    trainer = pl.Trainer(
        gpus=gpu_num
    )
    avg_at = None

    if config.infer_type == "mean":
        # Average the tagging output over all query files to obtain one
        # embedding per target key.
        avg_dataset = MusdbDataset(tracks=queries)
        avg_loader = DataLoader(
            dataset=avg_dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model=at_model,
            config=config,
            target_keys=config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders=avg_loader)
        avg_at = at_wrapper.avg_at

    model = ZeroShotASP(
        channels=1, config=config,
        at_model=at_model,
        dataset=dataset
    )

    # strict=False: presumably the separator checkpoint does not cover
    # every submodule (e.g. the frozen tagging model) — confirm.
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    exp_model = SeparatorModel(
        model=model,
        config=config,
        target_keys=config.test_key,
        avg_at=avg_at,
        using_wiener=False,
        calc_sdr=False,
        output_wav=True
    )
    trainer.test(exp_model, test_dataloaders=loader)
|
|
|
|
|
def test():
    """Evaluate the separator on a preprocessed test set (.npy track list).

    Loads config.testset_path (as produced by process_musdb), stacks each
    track as [mixture, source_1, ..., source_k] in config.test_key order,
    optionally computes averaged tagging embeddings over the first 90
    tracks of config.testavg_path (infer_type == "mean"), then runs the
    ZeroShotASP separator via trainer.test.
    """
    # Fix: run on GPU only when one is available. The original computed
    # device_name but never used it, hard-coded torch.device("cuda") into
    # a dead variable, and passed gpus=1 unconditionally — crashing on
    # CPU-only machines.
    gpu_num = 1 if torch.cuda.is_available() else 0
    assert config.test_key is not None, "there should be a separate key"
    create_folder(config.wave_output_path)

    test_data = np.load(config.testset_path, allow_pickle=True)
    print(len(test_data))
    mus_tracks = []

    # Re-pack each dict track into the [mixture, sources...] array layout
    # expected by MusdbDataset.
    for track in test_data:
        temp = []
        mixture = track["mixture"]
        temp.append(mixture)
        for dickey in config.test_key:
            source = track[dickey]
            temp.append(source)
        temp = np.array(temp)
        print(temp.shape)
        mus_tracks.append(temp)
    print(len(mus_tracks))
    dataset = MusdbDataset(tracks=mus_tracks)
    loader = DataLoader(
        dataset=dataset,
        num_workers=1,
        batch_size=1,
        shuffle=False
    )
    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"

    # Pretrained HTS-AT audio-tagging backbone; its embedding conditions
    # the separator on what source to extract.
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])
    trainer = pl.Trainer(
        gpus=gpu_num
    )
    avg_at = None

    if config.infer_type == "mean":
        # Average the tagging output over (up to) 90 tracks to obtain one
        # embedding per target key.
        avg_data = np.load(config.testavg_path, allow_pickle=True)[:90]
        print(len(avg_data))
        avgmus_tracks = []

        for track in avg_data:
            temp = []
            mixture = track["mixture"]
            temp.append(mixture)
            for dickey in config.test_key:
                source = track[dickey]
                temp.append(source)
            temp = np.array(temp)
            print(temp.shape)
            avgmus_tracks.append(temp)
        print(len(avgmus_tracks))
        avg_dataset = MusdbDataset(tracks=avgmus_tracks)
        avg_loader = DataLoader(
            dataset=avg_dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model=at_model,
            config=config,
            target_keys=config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders=avg_loader)
        avg_at = at_wrapper.avg_at

    model = ZeroShotASP(
        channels=1, config=config,
        at_model=at_model,
        dataset=dataset
    )
    # strict=False: presumably the separator checkpoint does not cover
    # every submodule (e.g. the frozen tagging model) — confirm.
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    exp_model = SeparatorModel(
        model=model,
        config=config,
        target_keys=config.test_key,
        avg_at=avg_at,
        using_wiener=config.using_wiener
    )
    trainer.test(exp_model, test_dataloaders=loader)
|
|
|
def train():
    """Train the ZeroShotASP separator, conditioned by a frozen HTS-AT tagger."""
    device_num = torch.cuda.device_count()
    print("each batch size:", config.batch_size // device_num)

    # Index files plus per-class sample-index (idc) lookups for the
    # training and evaluation splits.
    indexes_dir = os.path.join(config.dataset_path, "hdf5s", "indexes")
    train_index_path = os.path.join(indexes_dir, config.index_type + ".h5")
    train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle=True)
    eval_index_path = os.path.join(indexes_dir, "eval.h5")
    eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle=True)

    exp_dir = os.path.join(config.workspace, "results", config.exp_name)
    checkpoint_dir = os.path.join(exp_dir, "checkpoint")

    # Only touch the filesystem when not in debug mode.
    if not config.debug:
        create_folder(os.path.join(config.workspace, "results"))
        create_folder(exp_dir)
        create_folder(checkpoint_dir)
        dump_config(config, os.path.join(exp_dir, config.exp_name), False)

    train_set = LGSPDataset(
        index_path=train_index_path,
        idc=train_idc,
        config=config,
        factor=0.05,
        eval_mode=False,
    )
    eval_set = LGSPDataset(
        index_path=eval_index_path,
        idc=eval_idc,
        config=config,
        factor=0.05,
        eval_mode=True,
    )

    data_module = data_prep(
        train_dataset=train_set,
        eval_dataset=eval_set,
        device_num=device_num,
        config=config,
    )
    ckpt_callback = ModelCheckpoint(
        monitor="mixture_sdr",
        filename='l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
        save_top_k=10,
        mode="max",
    )

    # Frozen pretrained audio-tagging model whose embeddings condition
    # the separator.
    tagger_backbone = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head,
    )
    at_model = SEDWrapper(
        sed_model=tagger_backbone,
        config=htsat_config,
        dataset=None,
    )
    tagger_ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(tagger_ckpt["state_dict"])

    trainer = pl.Trainer(
        deterministic=True,
        default_root_dir=checkpoint_dir,
        gpus=device_num,
        val_check_interval=0.2,
        max_epochs=config.max_epoch,
        auto_lr_find=True,
        sync_batchnorm=True,
        callbacks=[ckpt_callback],
        # DDP only when more than one GPU is present; sampling is handled
        # manually by data_prep, hence replace_sampler_ddp=False.
        accelerator="ddp" if device_num > 1 else None,
        resume_from_checkpoint=None,
        replace_sampler_ddp=False,
        gradient_clip_val=1.0,
        num_sanity_val_steps=0,
    )
    model = ZeroShotASP(
        channels=1, config=config,
        at_model=at_model,
        dataset=train_set,
    )
    # Optionally warm-start the separator from a previous checkpoint.
    if config.resume_checkpoint is not None:
        resume_ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
        model.load_state_dict(resume_ckpt["state_dict"])

    trainer.fit(model, data_module)
|
|
|
def main():
    """CLI entry point: parse the sub-command and dispatch to its handler."""
    parser = argparse.ArgumentParser(description="latent genreal source separation parser")
    subparsers = parser.add_subparsers(dest="mode")
    for cmd in ("train", "test", "musdb_process", "save_idc", "weight_average", "inference"):
        subparsers.add_parser(cmd)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    # Seed everything for reproducibility before any mode runs.
    pl.utilities.seed.seed_everything(seed=config.random_seed)

    dispatch = {
        "train": train,
        "test": test,
        "musdb_process": process_musdb,
        "weight_average": weight_average,
        "save_idc": save_idc,
        "inference": inference,
    }
    handler = dispatch.get(args.mode)
    if handler is None:
        raise Exception("Error Mode!")
    handler()
|
|
|
|
|
# Script entry point: run the CLI dispatcher when executed directly.
if __name__ == '__main__':
    main()
|
|
|
|