import torch
from torch.cuda.amp import autocast
import numpy as np
import time
import os
import yaml
from matplotlib import pyplot as plt
import glob
from collections import OrderedDict
from tqdm import tqdm
import torch.distributed as dist
class Trainer(object):
    """
    A class that encapsulates the training loop for a PyTorch model.
    """
    def __init__(self, model, optimizer, criterion, train_dataloader, device, world_size=1, output_dim=2,
                 scheduler=None, val_dataloader=None, max_iter=np.inf, scaler=None,
                 grad_clip=False, exp_num=None, log_path=None, exp_name=None, plot_every=None,
                 cos_inc=False, range_update=None, accumulation_step=1, wandb_log=False, num_quantiles=1,
                 update_func=lambda x: x):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = scaler
        self.grad_clip = grad_clip
        self.cos_inc = cos_inc
        self.output_dim = output_dim
        self.scheduler = scheduler
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
        self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
        self.max_iter = max_iter
        self.device = device
        self.world_size = world_size
        self.exp_num = exp_num
        self.exp_name = exp_name
        self.log_path = log_path
        self.best_state_dict = None
        self.plot_every = plot_every
        self.logger = None
        self.range_update = range_update
        self.accumulation_step = accumulation_step
        self.wandb = wandb_log
        self.num_quantiles = num_quantiles
        self.update_func = update_func
        # if log_path is not None:
        #     self.logger = SummaryWriter(f'{self.log_path}/exp{self.exp_num}')
        #     # print(f"logger path: {self.log_path}/exp{self.exp_num}")
        #     print("logger is: ", self.logger)
    def get_sampler_from_dataloader(self, dataloader):
        """Return the sampler underlying a DataLoader (a DistributedSampler if present), or None."""
        if hasattr(dataloader, 'sampler'):
            if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
                return dataloader.sampler
            elif hasattr(dataloader.sampler, 'sampler'):
                return dataloader.sampler.sampler
        if hasattr(dataloader, 'batch_sampler') and hasattr(dataloader.batch_sampler, 'sampler'):
            return dataloader.batch_sampler.sampler
        return None
    def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
        """
        Fits the model for the given number of epochs.
        """
        min_loss = np.inf
        best_acc = 0
        train_loss, val_loss = [], []
        train_acc, val_acc = [], []
        lrs = []
        # self.optim_params['lr_history'] = []
        epochs_without_improvement = 0
        # main_process = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
        main_process = True  # change in a DDP setting
        print(f"Starting training for {num_epochs} epochs")
        print("is main process: ", main_process, flush=True)
        global_time = time.time()
        self.epoch = 0
        for epoch in range(num_epochs):
            self.epoch = epoch
            start_time = time.time()
            plot = (self.plot_every is not None) and (epoch % self.plot_every == 0)
            t_loss, t_acc = self.train_epoch(device, epoch=epoch)
            t_loss_mean = np.nanmean(t_loss)
            train_loss.extend(t_loss)
            global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
            if main_process:  # Only perform this on the master GPU
                train_acc.append(global_train_accuracy.mean().item())
            v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
            v_loss_mean = np.nanmean(v_loss)
            val_loss.extend(v_loss)
            global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)
            if main_process:  # Only perform this on the master GPU
                val_acc.append(global_val_accuracy.mean().item())
                current_objective = global_val_loss if best == 'loss' else global_val_accuracy.mean()
                improved = False
                if best == 'loss':
                    if current_objective < min_loss:
                        min_loss = current_objective
                        improved = True
                else:
                    if current_objective > best_acc:
                        best_acc = current_objective
                        improved = True
                if improved:
                    model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
                    print(f"saving model at {model_name}...")
                    torch.save(self.model.state_dict(), model_name)
                    self.best_state_dict = self.model.state_dict()
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1
                current_lr = self.optimizer.param_groups[0]['lr'] if self.scheduler is None \
                    else self.scheduler.get_last_lr()[0]
                lrs.append(current_lr)
                print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss: '
                      f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '
                      f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()}, '
                      f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time) / 3600:.2f} hr', flush=True)
                if epoch % 10 == 0:
                    print(os.system('nvidia-smi'))
            if epochs_without_improvement == early_stopping:
                print('early stopping!', flush=True)
                break
            if time.time() - global_time > (23.83 * 3600):
                print("time limit reached")
                break
        return {"num_epochs": num_epochs, "train_loss": train_loss,
                "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "lrs": lrs}
    def process_loss(self, acc, loss_mean):
        """Average accuracy and loss across workers when running under torch.distributed."""
        if torch.cuda.is_available() and torch.distributed.is_initialized():
            global_accuracy = torch.tensor(acc).cuda()  # Convert accuracy to a tensor on the GPU
            torch.distributed.reduce(global_accuracy, dst=0, op=torch.distributed.ReduceOp.SUM)
            global_loss = torch.tensor(loss_mean).cuda()  # Convert loss to a tensor on the GPU
            torch.distributed.reduce(global_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
            # Divide both loss and accuracy by world size
            world_size = torch.distributed.get_world_size()
            global_loss /= world_size
            global_accuracy /= world_size
        else:
            global_loss = torch.tensor(loss_mean)
            global_accuracy = torch.tensor(acc)
        return global_accuracy, global_loss
    def load_best_model(self, to_ddp=True, from_ddp=True):
        """Load the latest checkpoint from the experiment directory, stripping DDP 'module.' prefixes if needed."""
        data_dir = f'{self.log_path}/exp{self.exp_num}'
        # data_dir = f'{self.log_path}/exp29'  # for debugging
        state_dict_files = glob.glob(data_dir + '/*.pth')
        print("loading model from ", state_dict_files[-1])
        state_dict = torch.load(state_dict_files[-1]) if to_ddp else torch.load(state_dict_files[0], map_location=self.device)
        if from_ddp:
            print("loading distributed model")
            # Remove "module." from keys
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                if key.startswith('module.'):
                    while key.startswith('module.'):
                        key = key[7:]
                new_state_dict[key] = value
            state_dict = new_state_dict
        # print("state_dict: ", state_dict.keys())
        # print("model: ", self.model.state_dict().keys())
        self.model.load_state_dict(state_dict, strict=False)
    def check_gradients(self):
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm > 10:
                    print(f"Large gradient in {name}: {grad_norm}")
    def train_epoch(self, device, epoch):
        """
        Trains the model for one epoch.
        """
        if self.train_sampler is not None:
            try:
                self.train_sampler.set_epoch(epoch)
            except AttributeError:
                pass
        self.model.train()
        train_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.train_dl)
        for i, batch in enumerate(pbar):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            loss, acc, y = self.train_batch(batch, i, device)
            train_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
            if i > self.max_iter:
                break
        return train_loss, all_accs / total
    def train_batch(self, batch, batch_idx, device):
        x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
        x = x.to(device).float()
        fft = fft.to(device).float()
        y = y.to(device).float()
        # Stack the raw waveform and its FFT as two input channels: (batch, 2, n_points)
        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
        y_pred = self.model(x_fft).squeeze()
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        # get predicted classes
        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss, acc, y
    def eval_epoch(self, device, epoch):
        """
        Evaluates the model for one epoch.
        """
        self.model.eval()
        val_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            loss, acc, y = self.eval_batch(batch, i, device)
            val_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
            if i > self.max_iter:
                break
        return val_loss, all_accs / total
    def eval_batch(self, batch, batch_idx, device):
        x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
        x = x.to(device).float()
        fft = fft.to(device).float()
        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
        y = y.to(device).float()
        with torch.no_grad():
            y_pred = self.model(x_fft).squeeze()
            loss = self.criterion(y_pred, y)
        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss, acc, y
    def predict(self, test_dataloader, device):
        """
        Returns the predictions of the model on the given dataset.
        """
        self.model.eval()
        total = 0
        all_accs = 0
        predictions = []
        true_labels = []
        pbar = tqdm(test_dataloader)
        for i, batch in enumerate(pbar):
            x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
            x = x.to(device).float()
            fft = fft.to(device).float()
            x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
            y = y.to(device).float()
            with torch.no_grad():
                y_pred = self.model(x_fft).squeeze()
            loss = self.criterion(y_pred, y)
            probs = torch.sigmoid(y_pred)
            cls_pred = (probs > 0.5).float()
            acc = (cls_pred == y).sum()
            predictions.extend(cls_pred.cpu().numpy())
            true_labels.extend(y.cpu().numpy())
            all_accs += acc
            total += len(y)
            pbar.set_description("acc: {:.4f}".format(acc))
            if i > self.max_iter:
                break
        return predictions, true_labels, all_accs / total
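

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original training code).
# It shows one way to drive Trainer.fit() end to end. DummyAudioDataset and
# DummyModel are hypothetical stand-ins that mimic the batch layout this
# Trainer expects: batch['audio']['array'], batch['audio']['fft'],
# batch['label']; swap in your own dataset and model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import Dataset, DataLoader

    class DummyAudioDataset(Dataset):
        """Synthetic dataset producing the nested batch structure used above."""
        def __init__(self, n_samples=32, n_points=256):
            self.x = torch.randn(n_samples, n_points)      # raw waveform
            self.fft = torch.randn(n_samples, n_points)    # precomputed FFT features
            self.y = torch.randint(0, 2, (n_samples,)).float()  # binary labels

        def __len__(self):
            return len(self.y)

        def __getitem__(self, idx):
            return {"audio": {"array": self.x[idx], "fft": self.fft[idx]},
                    "label": self.y[idx]}

    class DummyModel(torch.nn.Module):
        """Minimal 2-channel model mapping (batch, 2, n_points) -> (batch, 1) logits."""
        def __init__(self, n_points=256):
            super().__init__()
            self.net = torch.nn.Sequential(torch.nn.Flatten(),
                                           torch.nn.Linear(2 * n_points, 1))

        def forward(self, x):
            return self.net(x)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DummyModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.BCEWithLogitsLoss()  # matches the sigmoid thresholding in train_batch
    train_dl = DataLoader(DummyAudioDataset(), batch_size=8, shuffle=True)
    val_dl = DataLoader(DummyAudioDataset(), batch_size=8)

    # fit() saves the best checkpoint to {log_path}/{exp_num}/{exp_name}.pth,
    # so that directory must exist before training.
    log_path, exp_num, exp_name = ".", 0, "dummy"
    os.makedirs(f"{log_path}/{exp_num}", exist_ok=True)

    trainer = Trainer(model, optimizer, criterion, train_dl, device,
                      output_dim=1, val_dataloader=val_dl,
                      log_path=log_path, exp_num=exp_num, exp_name=exp_name)
    results = trainer.fit(num_epochs=2, device=device, early_stopping=5)
    print("final train loss:", results["train_loss"][-1],
          "final val loss:", results["val_loss"][-1])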