import os
import argparse
import json
import warnings

import torch
import yaml
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from rich import print

import look2hear.datas
import look2hear.models
import look2hear.system
import look2hear.losses
import look2hear.metrics
import look2hear.utils
from look2hear.system import make_optimizer
from look2hear.utils import print_only, MyRichProgressBar, RichProgressBarTheme
from look2hear.utils.parser_utils import (
    prepare_parser_from_dict,
    parse_args_as_dict,
)

warnings.filterwarnings("ignore")

import wandb

wandb.login()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--conf_dir",
    default="local/conf.yml",
    help="Full path to the YAML training configuration.",
)


def main(config):
    # Data: instantiate the datamodule named in the config and build the loaders.
    print_only(
        "Instantiating datamodule <{}>".format(config["datamodule"]["data_name"])
    )
    datamodule = getattr(look2hear.datas, config["datamodule"]["data_name"])(
        **config["datamodule"]["data_config"]
    )
    datamodule.setup()
    train_loader, val_loader, test_loader = datamodule.make_loader

    # Model and optimizer.
    print_only(
        "Instantiating AudioNet <{}>".format(config["audionet"]["audionet_name"])
    )
    model = getattr(look2hear.models, config["audionet"]["audionet_name"])(
        sample_rate=config["datamodule"]["data_config"]["sample_rate"],
        **config["audionet"]["audionet_config"],
    )
    print_only(
        "Instantiating Optimizer <{}>".format(config["optimizer"]["optim_name"])
    )
    optimizer = make_optimizer(model.parameters(), **config["optimizer"])

    # Scheduler: either a torch.optim.lr_scheduler class, or the repo's
    # DPTNetScheduler, which Lightning must step per batch ("interval": "step").
    scheduler = None
    if config["scheduler"]["sche_name"]:
        print_only(
            "Instantiating Scheduler <{}>".format(config["scheduler"]["sche_name"])
        )
        if config["scheduler"]["sche_name"] != "DPTNetScheduler":
            scheduler = getattr(
                torch.optim.lr_scheduler, config["scheduler"]["sche_name"]
            )(optimizer=optimizer, **config["scheduler"]["sche_config"])
        else:
            scheduler = {
                "scheduler": getattr(
                    look2hear.system.schedulers, config["scheduler"]["sche_name"]
                )(
                    optimizer,
                    len(train_loader)
                    // config["datamodule"]["data_config"]["batch_size"],
                    64,
                ),
                "interval": "step",
            }

    # Just after instantiating, save the resolved config for easy loading later.
    config["main_args"]["exp_dir"] = os.path.join(
        os.getcwd(), "Experiments", "checkpoint", config["exp"]["exp_name"]
    )
    exp_dir = config["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(config, outfile)

    # Loss functions: wrap the configured SDR criterion for train and val.
    print_only(
        "Instantiating Loss, Train <{}>, Val <{}>".format(
            config["loss"]["train"]["sdr_type"], config["loss"]["val"]["sdr_type"]
        )
    )
    loss_func = {
        "train": getattr(look2hear.losses, config["loss"]["train"]["loss_func"])(
            getattr(look2hear.losses, config["loss"]["train"]["sdr_type"]),
            **config["loss"]["train"]["config"],
        ),
        "val": getattr(look2hear.losses, config["loss"]["val"]["loss_func"])(
            getattr(look2hear.losses, config["loss"]["val"]["sdr_type"]),
            **config["loss"]["val"]["config"],
        ),
    }

    # System: the LightningModule that ties together model, losses, loaders,
    # optimizer, and scheduler.
    print_only("Instantiating System <{}>".format(config["training"]["system"]))
    system = getattr(look2hear.system, config["training"]["system"])(
        audio_model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        scheduler=scheduler,
        config=config,
    )

    # Callbacks: top-k checkpointing on validation loss, optional early
    # stopping, and a rich progress bar.
    print_only("Instantiating ModelCheckpoint")
    callbacks = []
    checkpoint = ModelCheckpoint(
        dirpath=exp_dir,
        filename="{epoch}",
        monitor="val_loss/dataloader_idx_0",
        mode="min",
        save_top_k=5,
        verbose=True,
        save_last=True,
    )
    callbacks.append(checkpoint)
    if config["training"]["early_stop"]:
        print_only("Instantiating EarlyStopping")
        callbacks.append(EarlyStopping(**config["training"]["early_stop"]))
    callbacks.append(MyRichProgressBar(theme=RichProgressBarTheme()))

    # Don't request GPUs if none are available.
    gpus = config["training"]["gpus"] if torch.cuda.is_available() else None
    distributed_backend = "cuda" if torch.cuda.is_available() else None

    # Logger: Weights & Biases by default; a TensorBoard logger is kept as a
    # commented-out alternative.
    logger_dir = os.path.join(os.getcwd(), "Experiments", "tensorboard_logs")
    os.makedirs(os.path.join(logger_dir, config["exp"]["exp_name"]), exist_ok=True)
    # comet_logger = TensorBoardLogger(logger_dir, name=config["exp"]["exp_name"])
    comet_logger = WandbLogger(
        name=config["exp"]["exp_name"],
        save_dir=os.path.join(logger_dir, config["exp"]["exp_name"]),
        project="Real-work-dataset",
        # offline=True,
    )

    trainer = pl.Trainer(
        max_epochs=config["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        devices=gpus,
        accelerator=distributed_backend,
        strategy=DDPStrategy(find_unused_parameters=True),
        limit_train_batches=1.0,  # lower this for fast experiments
        gradient_clip_val=5.0,
        logger=comet_logger,
        sync_batchnorm=True,
        # precision="bf16-mixed",
        # num_sanity_val_steps=0,
        # fast_dev_run=True,
    )
    trainer.fit(system)
    print_only("Finished Training")

    # Record the top-k checkpoint scores, then reload the best checkpoint and
    # export the bare model weights in the repo's serialized format.
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.audio_model.serialize()
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))


if __name__ == "__main__":
    args = parser.parse_args()
    with open(args.conf_dir) as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    main(arg_dic)
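
# For reference, a minimal sketch of the conf.yml shape this script reads,
# inferred purely from the keys accessed above. Every class name and value
# below is a placeholder assumption, not a default shipped with look2hear;
# substitute the names your installed version actually provides.
#
#   exp:
#     exp_name: demo_run
#   datamodule:
#     data_name: SomeDataModule          # hypothetical class in look2hear.datas
#     data_config:
#       sample_rate: 16000
#       batch_size: 4
#   audionet:
#     audionet_name: SomeAudioNet        # hypothetical class in look2hear.models
#     audionet_config: {}
#   optimizer:
#     optim_name: adam
#     lr: 0.001
#   scheduler:
#     sche_name: ReduceLROnPlateau       # or DPTNetScheduler for per-step updates
#     sche_config:
#       factor: 0.5
#       patience: 5
#   loss:
#     train:
#       loss_func: SomeLossWrapper       # hypothetical class in look2hear.losses
#       sdr_type: some_sdr_criterion     # hypothetical criterion name
#       config: {}
#     val:
#       loss_func: SomeLossWrapper
#       sdr_type: some_sdr_criterion
#       config: {}
#   training:
#     system: SomeSystem                 # hypothetical class in look2hear.system
#     gpus: [0]
#     epochs: 200
#     early_stop:
#       monitor: val_loss/dataloader_idx_0
#       patience: 10
#       mode: min
#   main_args: {}                        # exp_dir is filled in at runtime
#
# Typical invocation (script name assumed):
#   python train.py --conf_dir local/conf.yml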