# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
__version__ = '1.0.4'

import random
import argparse
from tqdm.auto import tqdm
import os
import torch
import wandb
import numpy as np
import auraloss
import torch.nn as nn
from torch.optim import Adam, AdamW, SGD, RAdam, RMSprop
from torch.utils.data import DataLoader
from torch.cuda.amp.grad_scaler import GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from ml_collections import ConfigDict
import torch.nn.functional as F
from typing import List, Tuple, Dict, Union, Callable, Any
from dataset import MSSDataset
from utils import get_model_from_config
from valid import valid_multi_gpu, valid
from utils import bind_lora_to_model, load_start_checkpoint
import loralib as lora
import warnings

warnings.filterwarnings("ignore")

def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
    """
    Parse command-line arguments for configuring the model, dataset, and training parameters.

    Args:
        dict_args: Dict of command-line arguments. If None, arguments will be parsed from sys.argv.

    Returns:
        Namespace object containing parsed arguments and their values.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to start training")
    parser.add_argument("--results_path", type=str,
                        help="Path to folder where results will be stored (weights, metadata)")
    parser.add_argument("--data_path", nargs="+", type=str, help="Dataset data paths. You can provide several folders.")
    parser.add_argument("--dataset_type", type=int, default=1,
                        help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md")
    parser.add_argument("--valid_path", nargs="+", type=str,
                        help="Validation data paths. You can provide several folders.")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for the DataLoader")
    parser.add_argument("--pin_memory", action='store_true', help="Use pin_memory in the DataLoader")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help="List of GPU ids")
    parser.add_argument("--loss", type=str, nargs='+', choices=['masked_loss', 'mse_loss', 'l1_loss', 'multistft_loss'],
                        default=['masked_loss'], help="List of loss functions to use")
    parser.add_argument("--wandb_key", type=str, default='', help="wandb API key")
    parser.add_argument("--pre_valid", action='store_true', help="Run validation before training")
    parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
                        choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help="List of metrics to use.")
    parser.add_argument("--metric_for_scheduler", default="sdr",
                        choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help="Metric which will be used for the scheduler.")
    parser.add_argument("--train_lora", action='store_true', help="Train with LoRA")
    parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint with LoRA weights")

    if dict_args is not None:
        args = parser.parse_args([])
        args_dict = vars(args)
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()

    if args.metric_for_scheduler not in args.metrics:
        args.metrics += [args.metric_for_scheduler]
    return args

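# Example invocation (a minimal sketch; the config, data and results paths below are
# placeholders, not files shipped with this script):
#
#   python train.py \
#       --model_type mdx23c \
#       --config_path configs/my_config.yaml \
#       --results_path results/ \
#       --data_path /path/to/train_dataset \
#       --valid_path /path/to/valid_dataset \
#       --device_ids 0 \
#       --metrics sdr si_sdr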
def manual_seed(seed: int) -> None:
    """
    Set the random seed for reproducibility across Python, NumPy, and PyTorch.

    Args:
        seed: The seed value to set.
    """

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if multi-GPU
    torch.backends.cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = str(seed)

def initialize_environment(seed: int, results_path: str) -> None:
    """
    Initialize the environment by setting the random seed, configuring PyTorch settings,
    and creating the results directory.

    Args:
        seed: The seed value for reproducibility.
        results_path: Path to the directory where results will be stored.
    """

    manual_seed(seed)
    # Note: this overrides the deterministic flag set inside manual_seed(), keeping the
    # faster non-deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = False
    try:
        # 'spawn' is safer than 'fork' when DataLoader workers interact with CUDA.
        torch.multiprocessing.set_start_method('spawn')
    except Exception:
        # The start method may already be set; ignore in that case.
        pass
    os.makedirs(results_path, exist_ok=True)

def wandb_init(args: argparse.Namespace, config: Dict, device_ids: List[int], batch_size: int) -> None:
    """
    Initialize the Weights & Biases (wandb) logging system.

    Args:
        args: Parsed command-line arguments containing the wandb key.
        config: Configuration dictionary for the experiment.
        device_ids: List of GPU device IDs used for training.
        batch_size: Batch size for training.
    """

    if args.wandb_key is None or args.wandb_key.strip() == '':
        wandb.init(mode='disabled')
    else:
        wandb.login(key=args.wandb_key)
        wandb.init(project='msst', config={'config': config, 'args': args, 'device_ids': device_ids, 'batch_size': batch_size})

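# MSSDataset receives a metadata path under results_path keyed by dataset type;
# presumably this is used to cache pre-scanned track metadata between runs
# (an assumption based on how the path is constructed here; see dataset.py for the
# actual behavior).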
def prepare_data(config: Dict, args: argparse.Namespace, batch_size: int) -> DataLoader:
    """
    Prepare the training dataset and data loader.

    Args:
        config: Configuration dictionary for the dataset.
        args: Parsed command-line arguments containing dataset paths and settings.
        batch_size: Batch size for training.

    Returns:
        DataLoader object for the training dataset.
    """

    trainset = MSSDataset(
        config,
        args.data_path,
        batch_size=batch_size,
        metadata_path=os.path.join(args.results_path, f'metadata_{args.dataset_type}.pkl'),
        dataset_type=args.dataset_type,
    )

    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory
    )
    return train_loader

def initialize_model_and_device(model: torch.nn.Module, device_ids: List[int]) -> Tuple[Union[torch.device, str], torch.nn.Module]:
    """
    Initialize the model and assign it to the appropriate device (GPU or CPU).

    Args:
        model: The PyTorch model to be initialized.
        device_ids: List of GPU device IDs to use for parallel processing.

    Returns:
        A tuple containing the device and the model moved to that device.
    """

    if torch.cuda.is_available():
        if len(device_ids) <= 1:
            device = torch.device(f'cuda:{device_ids[0]}')
            model = model.to(device)
        else:
            device = torch.device(f'cuda:{device_ids[0]}')
            model = nn.DataParallel(model, device_ids=device_ids).to(device)
    else:
        device = 'cpu'
        model = model.to(device)
        print("CUDA is not available. Running on CPU.")

    return device, model

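# The 'prodigy' and 'adamw8bit' optimizer options rely on the optional prodigyopt and
# bitsandbytes packages, which are imported lazily below and must be installed separately.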
def get_optimizer(config: ConfigDict, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Initializes an optimizer based on the configuration.

    Args:
        config: Configuration object containing training parameters.
        model: PyTorch model whose parameters will be optimized.

    Returns:
        A PyTorch optimizer object configured based on the specified settings.
    """

    optim_params = dict()
    if 'optimizer' in config:
        optim_params = dict(config['optimizer'])
        print(f'Optimizer params from config:\n{optim_params}')

    name_optimizer = getattr(config.training, 'optimizer',
                             'No optimizer in config')

    if name_optimizer == 'adam':
        optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'adamw':
        optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'radam':
        optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'rmsprop':
        optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'prodigy':
        from prodigyopt import Prodigy
        # you can choose weight decay value based on your problem, 0 by default
        # We recommend using lr=1.0 (default) for all networks.
        optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'adamw8bit':
        import bitsandbytes as bnb
        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'sgd':
        print('Use SGD optimizer')
        optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
    else:
        print(f'Unknown optimizer: {name_optimizer}')
        exit()
    return optimizer

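# multistft_loss() flattens 4D stem outputs, presumably shaped
# (batch, stems, channels, time), into (batch, stems * channels, time) so that
# auraloss's MultiResolutionSTFTLoss can treat every stem/channel pair as an audio channel.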
def multistft_loss(y: torch.Tensor, y_: torch.Tensor,
                   loss_multistft: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> torch.Tensor:
    if len(y_.shape) == 4:
        y1_ = y_.reshape(y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3])
        y1 = y.reshape(y.shape[0], y.shape[1] * y.shape[2], y.shape[3])
    elif len(y_.shape) == 3:
        y1_, y1 = y_, y
    else:
        raise ValueError(f"Invalid shape for predicted array: {y_.shape}. Expected 3 or 4 dimensions.")
    return loss_multistft(y1_, y1)

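# masked_loss() computes an element-wise MSE and keeps only the values that fall below the
# q-th quantile (computed per source, after the transpose), which limits the influence of
# the hardest, highest-loss examples in the batch on the gradient.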
def masked_loss(y_: torch.Tensor, y: torch.Tensor, q: float, coarse: bool = True) -> torch.Tensor:
    loss = torch.nn.MSELoss(reduction='none')(y_, y).transpose(0, 1)
    if coarse:
        loss = loss.mean(dim=(-1, -2))
    loss = loss.reshape(loss.shape[0], -1)
    quantile = torch.quantile(loss.detach(), q, interpolation='linear', dim=1, keepdim=True)
    mask = loss < quantile
    return (loss * mask).mean()

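# choice_loss() builds the final training objective as a plain sum of the selected loss
# terms; the multi-resolution STFT term is divided by 1000, presumably to keep its
# magnitude comparable to the waveform-domain losses.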
def choice_loss(args: argparse.Namespace, config: ConfigDict) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Select and return the appropriate loss function based on the configuration and arguments.

    Args:
        args: Parsed command-line arguments containing flags for different loss functions.
        config: Configuration object containing loss settings and parameters.

    Returns:
        A loss function that can be applied to the predicted and ground truth tensors.
    """

    print(f'Losses for training: {args.loss}')
    loss_fns = []
    if 'masked_loss' in args.loss:
        loss_fns.append(
            lambda y_, y: masked_loss(y_, y, q=config['training']['q'], coarse=config['training']['coarse_loss_clip']))
    if 'mse_loss' in args.loss:
        loss_fns.append(nn.MSELoss())
    if 'l1_loss' in args.loss:
        loss_fns.append(F.l1_loss)
    if 'multistft_loss' in args.loss:
        loss_options = dict(config.get('loss_multistft', {}))
        loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)
        loss_fns.append(lambda y_, y: multistft_loss(y_, y, loss_multistft) / 1000)

    def multi_loss(y_, y):
        return sum(loss_fn(y_, y) for loss_fn in loss_fns)

    return multi_loss

def normalize_batch(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Normalize a batch by subtracting the mean of x and dividing by its standard deviation.
    The same statistics (taken from x, the mixture) are applied to y so the targets stay
    on the same scale as the input.

    Args:
        x: Tensor to normalize (the mixture).
        y: Tensor to normalize with the statistics of x (the targets).

    Returns:
        A tuple of normalized tensors (x, y).
    """

    mean = x.mean()
    std = x.std()
    if std != 0:
        x = (x - mean) / std
        y = (y - mean) / std
    return x, y

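# train_one_epoch() runs a standard AMP training loop: the loss is divided by
# gradient_accumulation_steps, gradients accumulate across that many batches, and the
# optimizer is stepped (through the GradScaler) every N batches or at the last batch of
# the epoch. For mel_band_roformer / bs_roformer the model computes its loss internally
# in forward(), so multi_loss is bypassed for those model types.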
def train_one_epoch(model: torch.nn.Module, config: ConfigDict, args: argparse.Namespace, optimizer: torch.optim.Optimizer,
                    device: torch.device, device_ids: List[int], epoch: int, use_amp: bool, scaler: torch.cuda.amp.GradScaler,
                    gradient_accumulation_steps: int, train_loader: torch.utils.data.DataLoader,
                    multi_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> None:
    """
    Train the model for one epoch.

    Args:
        model: The model to train.
        config: Configuration object containing training parameters.
        args: Command-line arguments with specific settings (e.g., model type).
        optimizer: Optimizer used for training.
        device: Device to run the model on (CPU or GPU).
        device_ids: List of GPU device IDs if using multiple GPUs.
        epoch: The current epoch number.
        use_amp: Whether to use automatic mixed precision (AMP) for training.
        scaler: Scaler for AMP to manage gradient scaling.
        gradient_accumulation_steps: Number of gradient accumulation steps before updating the optimizer.
        train_loader: DataLoader for the training dataset.
        multi_loss: The loss function to use during training.

    Returns:
        None
    """

    model.train().to(device)
    print(f'Train epoch: {epoch} Learning rate: {optimizer.param_groups[0]["lr"]}')
    loss_val = 0.
    total = 0

    normalize = getattr(config.training, 'normalize', False)

    pbar = tqdm(train_loader)
    for i, (batch, mixes) in enumerate(pbar):
        x = mixes.to(device)  # mixture
        y = batch.to(device)  # targets

        if normalize:
            x, y = normalize_batch(x, y)

        with torch.cuda.amp.autocast(enabled=use_amp):
            if args.model_type in ['mel_band_roformer', 'bs_roformer']:
                # loss is computed in forward pass
                loss = model(x, y)
                if isinstance(device_ids, (list, tuple)):
                    # With multiple GPUs, average the per-GPU partial losses
                    loss = loss.mean()
            else:
                y_ = model(x)
                loss = multi_loss(y_, y)

        loss /= gradient_accumulation_steps
        scaler.scale(loss).backward()
        if config.training.grad_clip:
            # Note: gradients are still scaled by GradScaler at this point; the usual AMP
            # recipe calls scaler.unscale_(optimizer) before clipping.
            nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)

        if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_loader) - 1):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        li = loss.item() * gradient_accumulation_steps
        loss_val += li
        total += 1
        pbar.set_postfix({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1)})
        wandb.log({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1), 'i': i})
        loss.detach()

    print(f'Training loss: {loss_val / total}')
    wandb.log({'train_loss': loss_val / total, 'epoch': epoch, 'learning_rate': optimizer.param_groups[0]['lr']})

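# save_weights() stores only the LoRA parameters (via lora.lora_state_dict) when LoRA
# training is enabled; otherwise it saves the full state_dict, unwrapping model.module
# for multi-GPU DataParallel models so the checkpoint keys carry no 'module.' prefix.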
def save_weights(store_path: str, model: torch.nn.Module, device_ids: List[int], train_lora: bool) -> None:
    if train_lora:
        torch.save(lora.lora_state_dict(model), store_path)
    else:
        state_dict = model.state_dict() if len(device_ids) <= 1 else model.module.state_dict()
        torch.save(
            state_dict,
            store_path
        )

def save_last_weights(args: argparse.Namespace, model: torch.nn.Module, device_ids: List[int]) -> None:
    """
    Save the model's state_dict to a file for later use.

    Args:
        args: Command-line arguments containing the results path and model type.
        model: The model whose weights will be saved.
        device_ids: List of GPU device IDs if using multiple GPUs.

    Returns:
        None
    """

    store_path = f'{args.results_path}/last_{args.model_type}.ckpt'
    train_lora = args.train_lora
    save_weights(store_path, model, device_ids, train_lora)

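# compute_epoch_metrics() validates after every epoch, saves a checkpoint only when the
# scheduler metric improves, and steps ReduceLROnPlateau with that metric (the scheduler
# is created with mode='max' in train_model, so higher is better, e.g. for SDR).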
def compute_epoch_metrics(model: torch.nn.Module, args: argparse.Namespace, config: ConfigDict,
                          device: torch.device, device_ids: List[int], best_metric: float,
                          epoch: int, scheduler: torch.optim.lr_scheduler._LRScheduler) -> float:
    """
    Compute and log the metrics for the current epoch, and save model weights if the metric improves.

    Args:
        model: The model to evaluate.
        args: Command-line arguments containing configuration paths and other settings.
        config: Configuration dictionary containing training settings.
        device: The device (CPU or GPU) used for evaluation.
        device_ids: List of GPU device IDs when using multiple GPUs.
        best_metric: The best metric value seen so far.
        epoch: The current epoch number.
        scheduler: The learning rate scheduler to adjust the learning rate.

    Returns:
        The updated best_metric.
    """

    if torch.cuda.is_available() and len(device_ids) > 1:
        metrics_avg, all_metrics = valid_multi_gpu(model, args, config, args.device_ids, verbose=False)
    else:
        metrics_avg, all_metrics = valid(model, args, config, device, verbose=False)
    metric_avg = metrics_avg[args.metric_for_scheduler]
    if metric_avg > best_metric:
        store_path = f'{args.results_path}/model_{args.model_type}_ep_{epoch}_{args.metric_for_scheduler}_{metric_avg:.4f}.ckpt'
        print(f'Store weights: {store_path}')
        train_lora = args.train_lora
        save_weights(store_path, model, device_ids, train_lora)
        best_metric = metric_avg
    scheduler.step(metric_avg)
    wandb.log({'metric_main': metric_avg, 'best_metric': best_metric})
    for metric_name in metrics_avg:
        wandb.log({f'metric_{metric_name}': metrics_avg[metric_name]})
    return best_metric

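# train_model() wires the pieces together: parse args, seed the environment, build the
# model and DataLoader, optionally load a checkpoint and/or attach LoRA adapters, then
# loop over epochs calling train_one_epoch(), saving the latest weights and validating
# after each epoch.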
def train_model(args: Union[Dict, None]) -> None:
    """
    Trains the model based on the provided arguments, including data preparation, optimizer setup,
    and loss calculation. The model is trained for multiple epochs with logging via wandb.

    Args:
        args: Dict of command-line argument overrides, or None to parse arguments from sys.argv.

    Returns:
        None
    """

    args = parse_args(args)

    initialize_environment(args.seed, args.results_path)
    model, config = get_model_from_config(args.model_type, args.config_path)
    use_amp = getattr(config.training, 'use_amp', True)
    device_ids = args.device_ids
    batch_size = config.training.batch_size * len(device_ids)

    wandb_init(args, config, device_ids, batch_size)

    train_loader = prepare_data(config, args, batch_size)

    if args.start_check_point:
        load_start_checkpoint(args, model, type_='train')

    if args.train_lora:
        model = bind_lora_to_model(config, model)
        lora.mark_only_lora_as_trainable(model)

    device, model = initialize_model_and_device(model, args.device_ids)

    if args.pre_valid:
        if torch.cuda.is_available() and len(device_ids) > 1:
            valid_multi_gpu(model, args, config, args.device_ids, verbose=True)
        else:
            valid(model, args, config, device, verbose=True)

    optimizer = get_optimizer(config, model)
    gradient_accumulation_steps = int(getattr(config.training, 'gradient_accumulation_steps', 1))

    # Reduce LR if no metric improvements for several epochs
    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=config.training.patience,
                                  factor=config.training.reduce_factor)

    multi_loss = choice_loss(args, config)
    scaler = GradScaler()
    best_metric = float('-inf')

    print(
        f"Instruments: {config.training.instruments}\n"
        f"Metrics for training: {args.metrics}. Metric for scheduler: {args.metric_for_scheduler}\n"
        f"Patience: {config.training.patience} "
        f"Reduce factor: {config.training.reduce_factor}\n"
        f"Batch size: {batch_size} "
        f"Grad accum steps: {gradient_accumulation_steps} "
        f"Effective batch size: {batch_size * gradient_accumulation_steps}\n"
        f"Dataset type: {args.dataset_type}\n"
        f"Optimizer: {config.training.optimizer}"
    )

    print(f'Train for: {config.training.num_epochs} epochs')

    for epoch in range(config.training.num_epochs):
        train_one_epoch(model, config, args, optimizer, device, device_ids, epoch,
                        use_amp, scaler, gradient_accumulation_steps, train_loader, multi_loss)
        save_last_weights(args, model, device_ids)
        best_metric = compute_epoch_metrics(model, args, config, device, device_ids, best_metric, epoch, scheduler)


if __name__ == "__main__":
    train_model(None)