|
import wandb |
|
import argparse |
|
import pytorch_lightning as pl |
|
|
|
from argparse import ArgumentParser |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from os.path import join |
|
|
|
|
|
from sgmse.util.other import set_torch_cuda_arch_list |
|
set_torch_cuda_arch_list() |
|
|
|
from sgmse.backbones.shared import BackboneRegistry |
|
from sgmse.data_module import SpecsDataModule |
|
from sgmse.sdes import SDERegistry |
|
from sgmse.model import ScoreModel |
|
|
|
|
|
def get_argparse_groups(parser, args=None):
    """Split parsed command-line arguments into one Namespace per argparse group.

    Args:
        parser: The ``argparse.ArgumentParser`` whose groups are inspected.
            This reads argparse's private ``_action_groups`` /
            ``_group_actions`` attributes, as the original code did.
        args: The ``argparse.Namespace`` produced by parsing with ``parser``.
            When omitted, ``parser.parse_args()`` is called to obtain it.
            Previously the function silently depended on a module-level
            ``args`` global, which raised ``NameError`` outside the script.

    Returns:
        dict mapping each group title (e.g. ``"Trainer"``) to an
        ``argparse.Namespace`` containing only that group's destinations.
        A destination absent from ``args`` (e.g. ``help``, whose default is
        suppressed) is mapped to ``None``.
    """
    if args is None:
        args = parser.parse_args()
    groups = {}
    for group in parser._action_groups:
        group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
        groups[group.title] = argparse.Namespace(**group_dict)
    return groups
|
|
|
|
|
if __name__ == '__main__':
    # Two-stage parsing: `base_parser` (no -h/--help) reads only the top-level
    # options so that the selected backbone/SDE classes can register their own
    # arguments on the full `parser` before the whole command line is parsed.
    base_parser = ArgumentParser(add_help=False)
    parser = ArgumentParser()
    for parser_ in (base_parser, parser):
        parser_.add_argument("--backbone", type=str, choices=BackboneRegistry.get_all_names(), default="ncsnpp")
        parser_.add_argument("--sde", type=str, choices=SDERegistry.get_all_names(), default="ouve")
        parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
        parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
        parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
        parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")

    # parse_known_args ignores the component-specific options not yet registered.
    temp_args, _ = base_parser.parse_known_args()

    backbone_cls = BackboneRegistry.get_by_name(temp_args.backbone)
    sde_class = SDERegistry.get_by_name(temp_args.sde)

    # Options in this group are forwarded verbatim to pl.Trainer below.
    trainer_parser = parser.add_argument_group("Trainer", description="Lightning Trainer")
    trainer_parser.add_argument("--accelerator", type=str, default="gpu", help="Supports passing different accelerator types.")
    trainer_parser.add_argument("--devices", default="auto", help="How many gpus to use.")
    trainer_parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Accumulate gradients.")

    ScoreModel.add_argparse_args(
        parser.add_argument_group("ScoreModel", description=ScoreModel.__name__))
    sde_class.add_argparse_args(
        parser.add_argument_group("SDE", description=sde_class.__name__))
    backbone_cls.add_argparse_args(
        parser.add_argument_group("Backbone", description=backbone_cls.__name__))

    data_module_cls = SpecsDataModule
    data_module_cls.add_argparse_args(
        parser.add_argument_group("DataModule", description=data_module_cls.__name__))

    args = parser.parse_args()
    arg_groups = get_argparse_groups(parser)

    # The model receives the merged options of every component group.
    model = ScoreModel(
        backbone=args.backbone, sde=args.sde, data_module_cls=data_module_cls,
        **{
            **vars(arg_groups['ScoreModel']),
            **vars(arg_groups['SDE']),
            **vars(arg_groups['Backbone']),
            **vars(arg_groups['DataModule']),
        }
    )

    if args.nolog:
        logger = None
    else:
        # Honour --log_dir for the wandb files as well, so all run output
        # (logger files and checkpoints) lives under one configurable root.
        logger = WandbLogger(project="sgmse", log_model=True, save_dir=args.log_dir, name=args.wandb_name)
        logger.experiment.log_code(".")

    # Checkpoints are only written when logging is enabled, because the
    # per-run directory is named after the logger's version.
    if logger is not None:
        ckpt_dir = join(args.log_dir, str(logger.version))
        callbacks = [ModelCheckpoint(dirpath=ckpt_dir, save_last=True, filename='{epoch}-last')]
        # NOTE(review): `num_eval_files` is presumably registered by one of the
        # component add_argparse_args calls above — verify. When non-zero, keep
        # the best checkpoints by the "pesq" and "si_sdr" monitored metrics.
        if args.num_eval_files:
            checkpoint_callback_pesq = ModelCheckpoint(dirpath=ckpt_dir,
                save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=ckpt_dir,
                save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
            callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
    else:
        callbacks = None

    trainer = pl.Trainer(
        **vars(arg_groups['Trainer']),
        strategy="ddp", logger=logger,
        log_every_n_steps=10, num_sanity_val_steps=0,
        callbacks=callbacks
    )

    # ckpt_path=None starts fresh; otherwise training resumes from --ckpt.
    trainer.fit(model, ckpt_path=args.ckpt)