import os
import random
from typing import Union

import soundfile as sf
import torch
import yaml
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from scipy.io import wavfile
import warnings
import torchaudio

warnings.filterwarnings("ignore")

import look2hear.models
import look2hear.datas
from look2hear.metrics import MetricsTracker
from look2hear.utils import (
    tensors_to_device,
    RichProgressBarTheme,
    MyMetricsTextColumn,
    BatchesProcessedColumn,
)
from rich.progress import (
    BarColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--conf_dir",
    default="local/mixit_conf.yml",
    help="Full path to save best validation model",
)

# Metrics this script is meant to report (currently unused: the
# MetricsTracker update calls below are commented out upstream).
compute_metrics = ["si_sdr", "sdr"]
os.environ['CUDA_VISIBLE_DEVICES'] = "8"


def main(config):
    """Run a trained separation model over the test set and save estimates.

    Loads the best checkpoint of the experiment named in
    ``config["train_conf"]["exp"]["exp_name"]``, builds the datamodule
    declared in the training config, and writes each estimated source of
    the selected test example to ``./result/TIGER/idx{idx}/s{i}/<key>``
    as 16 kHz wav files via torchaudio.

    Args:
        config: dict with a ``"train_conf"`` entry holding the parsed
            training YAML (expects ``exp``, ``main_args``, ``audionet``,
            ``datamodule`` and ``training`` sections).
    """
    metricscolumn = MyMetricsTextColumn(style=RichProgressBarTheme.metrics)
    progress = Progress(
        TextColumn("[bold blue]Testing", justify="right"),
        BarColumn(bar_width=None),
        "•",
        BatchesProcessedColumn(style=RichProgressBarTheme.batch_progress),
        "•",
        TransferSpeedColumn(),
        "•",
        TimeRemainingColumn(),
        "•",
        metricscolumn,
    )

    # Resolve the experiment directory and the checkpoint written by training.
    config["train_conf"]["main_args"]["exp_dir"] = os.path.join(
        os.getcwd(), "Experiments", "checkpoint", config["train_conf"]["exp"]["exp_name"]
    )
    model_path = os.path.join(
        config["train_conf"]["main_args"]["exp_dir"], "best_model.pth"
    )

    model = getattr(
        look2hear.models, config["train_conf"]["audionet"]["audionet_name"]
    ).from_pretrain(
        model_path,
        sample_rate=config["train_conf"]["datamodule"]["data_config"]["sample_rate"],
        **config["train_conf"]["audionet"]["audionet_config"],
    )
    if config["train_conf"]["training"]["gpus"]:
        model.to("cuda")
    # Use the model's actual device so inputs follow it (cuda or cpu).
    model_device = next(model.parameters()).device

    datamodule = getattr(
        look2hear.datas, config["train_conf"]["datamodule"]["data_name"]
    )(**config["train_conf"]["datamodule"]["data_config"])
    datamodule.setup()
    _, _, test_set = datamodule.make_sets

    ex_save_dir = os.path.join(config["train_conf"]["main_args"]["exp_dir"], "results/")
    os.makedirs(ex_save_dir, exist_ok=True)
    metrics = MetricsTracker(save_file=os.path.join(ex_save_dir, "metrics.csv"))

    # Inference only — run the whole loop under no_grad via a real context
    # manager (the previous torch.no_grad().__enter__() was never exited).
    with torch.no_grad(), progress:
        for idx in progress.track(range(len(test_set))):
            # NOTE(review): only example 825 is processed, as in the
            # original script — presumably a one-off debugging run; drop
            # this guard to evaluate the full test set.
            if idx != 825:
                continue
            # Forward the network on the mixture.
            mix, sources, key = tensors_to_device(
                test_set[idx], device=model_device
            )
            est_sources = model(mix[None])
            est_sources_np = est_sources.squeeze(0)
            save_dir = os.path.join("./result/TIGER", "idx{}".format(idx))
            # Write one subdirectory per estimated source, keeping only
            # the basename of the example key as the file name.
            for i in range(est_sources_np.shape[0]):
                os.makedirs(os.path.join(save_dir, "s{}/".format(i + 1)), exist_ok=True)
                torchaudio.save(
                    os.path.join(save_dir, "s{}/".format(i + 1)) + key.split("/")[-1],
                    est_sources_np[i].unsqueeze(0).cpu(),
                    16000,
                )
    metrics.final()


if __name__ == "__main__":
    args = parser.parse_args()
    arg_dic = dict(vars(args))
    # Load the training config referenced by --conf_dir and attach it.
    with open(args.conf_dir, "rb") as f:
        train_conf = yaml.safe_load(f)
    arg_dic["train_conf"] = train_conf
    main(arg_dic)