Spaces:
Paused
Paused
| import argparse | |
| import logging | |
| import os | |
| import pathlib | |
| from typing import List, NoReturn | |
| import lightning.pytorch as pl | |
| from lightning.pytorch.strategies import DDPStrategy | |
| from torch.utils.tensorboard import SummaryWriter | |
| from data.datamodules import * | |
| from utils import create_logging, parse_yaml | |
| from models.resunet import * | |
| from losses import get_loss_function | |
| from models.audiosep import AudioSep, get_model_class | |
| from data.waveform_mixers import SegmentMixer | |
| from models.clap_encoder import CLAP_Encoder | |
| from callbacks.base import CheckpointEveryNSteps | |
| from optimizers.lr_schedulers import get_lr_lambda | |
def get_dirs(
    workspace: str,
    filename: str,
    config_yaml: str,
    devices_num: int
) -> List[str]:
    r"""Create the workspace sub-directories of a run and return their paths.

    Every path has the layout
    ``<workspace>/<kind>/<filename>/<yaml stem>,devices=<devices_num>``.
    As a side effect, ``create_logging(logs_dir, filemode="w")`` is called.

    Args:
        workspace (str): directory of workspace (created if missing)
        filename (str): filename of current .py file, used as a sub-directory
        config_yaml (str): config yaml path; only its stem is used
        devices_num (int): number of devices, embedded in the run name

    Returns:
        checkpoints_dir (str): directory to save checkpoints (created)
        logs_dir (str): directory to save logs (created)
        tf_logs_dir (str): directory to save TensorBoard logs (path only,
            it is not created here)
        statistics_path (str): path of the ``statistics.pkl`` file (its
            parent directory is created)

    NOTE: the four values are returned as a tuple; the ``List[str]``
    annotation is kept unchanged for interface compatibility.
    """
    os.makedirs(workspace, exist_ok=True)

    yaml_name = pathlib.Path(config_yaml).stem

    # Shared leaf directory name that identifies this run.
    run_name = "{},devices={}".format(yaml_name, devices_num)

    # Directory to save checkpoints
    checkpoints_dir = os.path.join(workspace, "checkpoints", filename, run_name)
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Directory to save logs
    logs_dir = os.path.join(workspace, "logs", filename, run_name)
    os.makedirs(logs_dir, exist_ok=True)

    create_logging(logs_dir, filemode="w")

    # The command-line namespace only exists as a module-level global when
    # this file runs as a script. Log it when present instead of raising
    # NameError when the function is called from an importing module.
    cli_args = globals().get("args")
    if cli_args is not None:
        logging.info(cli_args)

    # Directory to save TensorBoard logs
    tf_logs_dir = os.path.join(workspace, "tf_logs", filename, run_name)

    # Path to save statistics
    statistics_path = os.path.join(
        workspace,
        "statistics",
        filename,
        run_name,
        "statistics.pkl",
    )
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, tf_logs_dir, statistics_path
def get_data_module(
    config_yaml: str,
    num_workers: int,
    batch_size: int,
) -> DataModule:
    r"""Build the training DataModule described by a yaml config.

    Usage example retained from the original documentation:

    code-block:: python
        data_module.setup()
        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        config_yaml: str, path of the config file; its 'data' section
            provides 'sampling_rate', 'segment_seconds' and 'datafiles'
        num_workers: int, forwarded unchanged to DataModule
        batch_size: int, forwarded unchanged to DataModule

    Returns:
        data_module: DataModule wrapping an AudioTextDataset
    """
    # Only the 'data' section of the configuration is needed here.
    data_cfg = parse_yaml(config_yaml)['data']
    sr = data_cfg['sampling_rate']
    clip_seconds = data_cfg['segment_seconds']
    audio_text_files = data_cfg['datafiles']

    train_set = AudioTextDataset(
        datafiles=audio_text_files,
        sampling_rate=sr,
        max_clip_len=clip_seconds,
    )

    return DataModule(
        train_dataset=train_set,
        num_workers=num_workers,
        batch_size=batch_size
    )
def train(args) -> None:
    r"""Train the separation model and save checkpoints.

    Args:
        args: namespace providing the attributes
            workspace (str): directory of workspace
            config_yaml (str): path of the training config file
            filename (str): name of the running script
            resume_checkpoint_path (str): checkpoint passed to
                ``trainer.fit(ckpt_path=...)``; an empty value (or None)
                means no checkpoint is passed

    Raises:
        NotImplementedError: if configs['model']['query_net'] is not 'CLAP'
    """
    # The file never imports ``torch`` by name (only
    # ``torch.utils.tensorboard.SummaryWriter``), so import it here rather
    # than rely on it arriving through one of the wildcard imports.
    import torch

    # arguments & parameters
    workspace = args.workspace
    config_yaml = args.config_yaml
    filename = args.filename

    devices_num = torch.cuda.device_count()

    # Read config file.
    configs = parse_yaml(config_yaml)
    data_cfg = configs['data']
    model_cfg = configs['model']
    train_cfg = configs['train']
    optim_cfg = train_cfg["optimizer"]

    # Configuration of data
    max_mix_num = data_cfg['max_mix_num']
    lower_db = data_cfg['loudness_norm']['lower_db']
    higher_db = data_cfg['loudness_norm']['higher_db']

    # Configuration of the separation model
    query_net = model_cfg['query_net']
    model_type = model_cfg['model_type']
    input_channels = model_cfg['input_channels']
    output_channels = model_cfg['output_channels']
    condition_size = model_cfg['condition_size']
    use_text_ratio = model_cfg['use_text_ratio']

    # Configuration of the trainer
    num_nodes = train_cfg['num_nodes']
    batch_size = train_cfg['batch_size_per_device']
    sync_batchnorm = train_cfg['sync_batchnorm']
    num_workers = train_cfg['num_workers']
    loss_type = train_cfg['loss_type']
    save_step_frequency = train_cfg['save_step_frequency']

    optimizer_type = optim_cfg["optimizer_type"]
    learning_rate = float(optim_cfg['learning_rate'])
    lr_lambda_type = optim_cfg['lr_lambda_type']
    warm_up_steps = optim_cfg['warm_up_steps']
    reduce_lr_steps = optim_cfg['reduce_lr_steps']

    # An empty (or missing) path means "do not resume from a checkpoint".
    resume_checkpoint_path = args.resume_checkpoint_path
    if not resume_checkpoint_path:
        resume_checkpoint_path = None

    # Get directories and paths. get_dirs() also calls create_logging(), so
    # the messages below are emitted only after that setup has run.
    checkpoints_dir, _logs_dir, tf_logs_dir, _statistics_path = get_dirs(
        workspace, filename, config_yaml, devices_num,
    )

    if resume_checkpoint_path is not None:
        logging.info(f'Finetuning AudioSep with checkpoint [{resume_checkpoint_path}]')

    logging.info(configs)

    # data module
    data_module = get_data_module(
        config_yaml=config_yaml,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # model
    Model = get_model_class(model_type=model_type)

    ss_model = Model(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # loss function
    loss_function = get_loss_function(loss_type)

    segment_mixer = SegmentMixer(
        max_mix_num=max_mix_num,
        lower_db=lower_db,
        higher_db=higher_db
    )

    # Only the CLAP query encoder is supported.
    if query_net == 'CLAP':
        query_encoder = CLAP_Encoder()
    else:
        raise NotImplementedError

    lr_lambda_func = get_lr_lambda(
        lr_lambda_type=lr_lambda_type,
        warm_up_steps=warm_up_steps,
        reduce_lr_steps=reduce_lr_steps,
    )

    # pytorch-lightning model
    pl_model = AudioSep(
        ss_model=ss_model,
        waveform_mixer=segment_mixer,
        query_encoder=query_encoder,
        loss_function=loss_function,
        optimizer_type=optimizer_type,
        learning_rate=learning_rate,
        lr_lambda_func=lr_lambda_func,
        use_text_ratio=use_text_ratio
    )

    checkpoint_every_n_steps = CheckpointEveryNSteps(
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # NOTE(review): the writer is never referenced again in this function.
    # It is still constructed so that whatever SummaryWriter does with
    # tf_logs_dir is unchanged - confirm whether it can be dropped.
    summary_writer = SummaryWriter(log_dir=tf_logs_dir)

    callbacks = [checkpoint_every_n_steps]

    trainer = pl.Trainer(
        accelerator='auto',
        devices='auto',
        strategy='ddp_find_unused_parameters_true',
        num_nodes=num_nodes,
        precision="32-true",
        logger=None,
        callbacks=callbacks,
        fast_dev_run=False,
        max_epochs=-1,
        log_every_n_steps=50,
        use_distributed_sampler=True,
        sync_batchnorm=sync_batchnorm,
        num_sanity_val_steps=2,
        # Checkpoints are written by the CheckpointEveryNSteps callback.
        enable_checkpointing=False,
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    # Fit, evaluate, and save checkpoints.
    trainer.fit(
        model=pl_model,
        train_dataloaders=None,
        val_dataloaders=None,
        datamodule=data_module,
        ckpt_path=resume_checkpoint_path,
    )
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )
    # Optional: train() treats the empty default as "no checkpoint".
    # Declaring the option required made that default unreachable.
    parser.add_argument(
        "--resume_checkpoint_path",
        type=str,
        required=False,
        default='',
        help="Path of pretrained checkpoint for finetuning.",
    )

    args = parser.parse_args()
    # The script's own name is used as a sub-directory of the workspace.
    args.filename = pathlib.Path(__file__).stem

    train(args)