Spaces:
Running
Running
| # Deep learning | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.backends.cudnn as cudnn | |
| from torch.utils.data import DataLoader | |
| from utils import CustomDataset, CustomDatasetMultitask, RMSELoss, normalize_smiles | |
| # Data | |
| import pandas as pd | |
| import numpy as np | |
| # Standard library | |
| import random | |
| import args | |
| import os | |
| import shutil | |
| from tqdm import tqdm | |
| # Machine Learning | |
| from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score, roc_auc_score, roc_curve, auc, precision_recall_curve | |
| from scipy import stats | |
| from utils import RMSE, sensitivity, specificity | |
| class Trainer: | |
| def __init__(self, raw_data, dataset_name, target, batch_size, hparams, | |
| target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): | |
| # data | |
| self.df_train = raw_data[0] | |
| self.df_valid = raw_data[1] | |
| self.df_test = raw_data[2] | |
| self.dataset_name = dataset_name | |
| self.target = target | |
| self.batch_size = batch_size | |
| self.hparams = hparams | |
| self._prepare_data() | |
| # config | |
| self.target_metric = target_metric | |
| self.seed = seed | |
| self.smi_ted_version = smi_ted_version | |
| self.checkpoints_folder = checkpoints_folder | |
| self.restart_filename = restart_filename | |
| self.start_epoch = 1 | |
| self.save_every_epoch = save_every_epoch | |
| self.save_ckpt = save_ckpt | |
| self.device = device | |
| self.best_vloss = float('inf') | |
| self.last_filename = None | |
| self._set_seed(seed) | |
| def _prepare_data(self): | |
| # normalize dataset | |
| self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles) | |
| self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles) | |
| self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles) | |
| self.df_train = self.df_train.dropna(subset=['canon_smiles']) | |
| self.df_valid = self.df_valid.dropna(subset=['canon_smiles']) | |
| self.df_test = self.df_test.dropna(subset=['canon_smiles']) | |
| # create dataloader | |
| self.train_loader = DataLoader( | |
| CustomDataset(self.df_train, self.target), | |
| batch_size=self.batch_size, | |
| shuffle=True, | |
| pin_memory=True | |
| ) | |
| self.valid_loader = DataLoader( | |
| CustomDataset(self.df_valid, self.target), | |
| batch_size=self.batch_size, | |
| shuffle=False, | |
| pin_memory=True | |
| ) | |
| self.test_loader = DataLoader( | |
| CustomDataset(self.df_test, self.target), | |
| batch_size=self.batch_size, | |
| shuffle=False, | |
| pin_memory=True | |
| ) | |
| def compile(self, model, optimizer, loss_fn): | |
| self.model = model | |
| self.optimizer = optimizer | |
| self.loss_fn = loss_fn | |
| self._print_configuration() | |
| if self.restart_filename: | |
| self._load_checkpoint(self.restart_filename) | |
| print('Checkpoint restored!') | |
| def fit(self, max_epochs=500): | |
| for epoch in range(self.start_epoch, max_epochs+1): | |
| print(f'\n=====Epoch [{epoch}/{max_epochs}]=====') | |
| # training | |
| self.model.to(self.device) | |
| self.model.train() | |
| train_loss = self._train_one_epoch() | |
| # validation | |
| self.model.eval() | |
| val_preds, val_loss, val_metrics = self._validate_one_epoch(self.valid_loader) | |
| for m in val_metrics.keys(): | |
| print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}") | |
| ############################### Save Finetune checkpoint ####################################### | |
| if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt: | |
| # remove old checkpoint | |
| if (self.last_filename != None) and (not self.save_every_epoch): | |
| os.remove(os.path.join(self.checkpoints_folder, self.last_filename)) | |
| # filename | |
| model_name = f'{str(self.model)}-Finetune' | |
| self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt" | |
| # update best loss | |
| self.best_vloss = val_loss | |
| # save checkpoint | |
| print('Saving checkpoint...') | |
| self._save_checkpoint(epoch, self.last_filename) | |
| def evaluate(self, verbose=True): | |
| if verbose: | |
| print("\n=====Test Evaluation=====") | |
| if self.smi_ted_version == 'v1': | |
| import smi_ted_light.load as load | |
| elif self.smi_ted_version == 'v2': | |
| import smi_ted_large.load as load | |
| else: | |
| raise Exception('Please, specify the SMI-TED version: `v1` or `v2`.') | |
| # copy vocabulary to checkpoint folder | |
| if not os.path.exists(os.path.join(self.checkpoints_folder, 'bert_vocab_curated.txt')): | |
| smi_ted_path = os.path.dirname(load.__file__) | |
| shutil.copy(os.path.join(smi_ted_path, 'bert_vocab_curated.txt'), self.checkpoints_folder) | |
| # load model for inference | |
| model_inf = load.load_smi_ted( | |
| folder=self.checkpoints_folder, | |
| ckpt_filename=self.last_filename, | |
| eval=True, | |
| ).to(self.device) | |
| # set model evaluation mode | |
| model_inf.eval() | |
| # evaluate on test set | |
| tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf) | |
| if verbose: | |
| # show metrics | |
| for m in tst_metrics.keys(): | |
| print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}") | |
| # save predictions | |
| pd.DataFrame(tst_preds).to_csv( | |
| os.path.join( | |
| self.checkpoints_folder, | |
| f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'), | |
| index=False | |
| ) | |
| def _train_one_epoch(self): | |
| raise NotImplementedError | |
| def _validate_one_epoch(self, data_loader, model=None): | |
| raise NotImplementedError | |
| def _print_configuration(self): | |
| print('----Finetune information----') | |
| print('Dataset:\t', self.dataset_name) | |
| print('Target:\t\t', self.target) | |
| print('Batch size:\t', self.batch_size) | |
| print('LR:\t\t', self._get_lr()) | |
| print('Device:\t\t', self.device) | |
| print('Optimizer:\t', self.optimizer.__class__.__name__) | |
| print('Loss function:\t', self.loss_fn.__class__.__name__) | |
| print('Seed:\t\t', self.seed) | |
| print('Train size:\t', self.df_train.shape[0]) | |
| print('Valid size:\t', self.df_valid.shape[0]) | |
| print('Test size:\t', self.df_test.shape[0]) | |
| def _load_checkpoint(self, filename): | |
| ckpt_path = os.path.join(self.checkpoints_folder, filename) | |
| ckpt_dict = torch.load(ckpt_path, map_location='cpu') | |
| self.model.load_state_dict(ckpt_dict['MODEL_STATE']) | |
| self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1 | |
| self.best_vloss = ckpt_dict['finetune_info']['best_vloss'] | |
| def _save_checkpoint(self, current_epoch, filename): | |
| if not os.path.exists(self.checkpoints_folder): | |
| os.makedirs(self.checkpoints_folder) | |
| ckpt_dict = { | |
| 'MODEL_STATE': self.model.state_dict(), | |
| 'EPOCHS_RUN': current_epoch, | |
| 'hparams': vars(self.hparams), | |
| 'finetune_info': { | |
| 'dataset': self.dataset_name, | |
| 'target`': self.target, | |
| 'batch_size': self.batch_size, | |
| 'lr': self._get_lr(), | |
| 'device': self.device, | |
| 'optim': self.optimizer.__class__.__name__, | |
| 'loss_fn': self.loss_fn.__class__.__name__, | |
| 'train_size': self.df_train.shape[0], | |
| 'valid_size': self.df_valid.shape[0], | |
| 'test_size': self.df_test.shape[0], | |
| 'best_vloss': self.best_vloss, | |
| }, | |
| 'seed': self.seed, | |
| } | |
| assert list(ckpt_dict.keys()) == ['MODEL_STATE', 'EPOCHS_RUN', 'hparams', 'finetune_info', 'seed'] | |
| torch.save(ckpt_dict, os.path.join(self.checkpoints_folder, filename)) | |
| def _set_seed(self, value): | |
| random.seed(value) | |
| torch.manual_seed(value) | |
| np.random.seed(value) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed(value) | |
| torch.cuda.manual_seed_all(value) | |
| cudnn.deterministic = True | |
| cudnn.benchmark = False | |
| def _get_lr(self): | |
| for param_group in self.optimizer.param_groups: | |
| return param_group['lr'] | |
| class TrainerRegressor(Trainer): | |
| def __init__(self, raw_data, dataset_name, target, batch_size, hparams, | |
| target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): | |
| super().__init__(raw_data, dataset_name, target, batch_size, hparams, | |
| target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device) | |
| def _train_one_epoch(self): | |
| running_loss = 0.0 | |
| for idx, data in enumerate(pbar := tqdm(self.train_loader)): | |
| # Every data instance is an input + label pair | |
| smiles, targets = data | |
| targets = targets.clone().detach().to(self.device) | |
| # zero the parameter gradients (otherwise they are accumulated) | |
| self.optimizer.zero_grad() | |
| # Make predictions for this batch | |
| embeddings = self.model.extract_embeddings(smiles).to(self.device) | |
| outputs = self.model.net(embeddings).squeeze() | |
| # Compute the loss and its gradients | |
| loss = self.loss_fn(outputs, targets) | |
| loss.backward() | |
| # Adjust learning weights | |
| self.optimizer.step() | |
| # print statistics | |
| running_loss += loss.item() | |
| # progress bar | |
| pbar.set_description('[TRAINING]') | |
| pbar.set_postfix(loss=running_loss/(idx+1)) | |
| pbar.refresh() | |
| return running_loss / len(self.train_loader) | |
| def _validate_one_epoch(self, data_loader, model=None): | |
| data_targets = [] | |
| data_preds = [] | |
| running_loss = 0.0 | |
| model = self.model if model is None else model | |
| with torch.no_grad(): | |
| for idx, data in enumerate(pbar := tqdm(data_loader)): | |
| # Every data instance is an input + label pair | |
| smiles, targets = data | |
| targets = targets.clone().detach().to(self.device) | |
| # Make predictions for this batch | |
| embeddings = model.extract_embeddings(smiles).to(self.device) | |
| predictions = model.net(embeddings).squeeze() | |
| # Compute the loss | |
| loss = self.loss_fn(predictions, targets) | |
| data_targets.append(targets.view(-1)) | |
| data_preds.append(predictions.view(-1)) | |
| # print statistics | |
| running_loss += loss.item() | |
| # progress bar | |
| pbar.set_description('[EVALUATION]') | |
| pbar.set_postfix(loss=running_loss/(idx+1)) | |
| pbar.refresh() | |
| # Put together predictions and labels from batches | |
| preds = torch.cat(data_preds, dim=0).cpu().numpy() | |
| tgts = torch.cat(data_targets, dim=0).cpu().numpy() | |
| # Compute metrics | |
| mae = mean_absolute_error(tgts, preds) | |
| r2 = r2_score(tgts, preds) | |
| rmse = RMSE(preds, tgts) | |
| spearman = stats.spearmanr(tgts, preds).statistic # scipy 1.12.0 | |
| # Rearange metrics | |
| metrics = { | |
| 'mae': mae, | |
| 'r2': r2, | |
| 'rmse': rmse, | |
| 'spearman': spearman, | |
| } | |
| return preds, running_loss / len(data_loader), metrics | |
| class TrainerClassifier(Trainer): | |
| def __init__(self, raw_data, dataset_name, target, batch_size, hparams, | |
| target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): | |
| super().__init__(raw_data, dataset_name, target, batch_size, hparams, | |
| target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device) | |
| def _train_one_epoch(self): | |
| running_loss = 0.0 | |
| for idx, data in enumerate(pbar := tqdm(self.train_loader)): | |
| # Every data instance is an input + label pair | |
| smiles, targets = data | |
| targets = targets.clone().detach().to(self.device) | |
| # zero the parameter gradients (otherwise they are accumulated) | |
| self.optimizer.zero_grad() | |
| # Make predictions for this batch | |
| embeddings = self.model.extract_embeddings(smiles).to(self.device) | |
| outputs = self.model.net(embeddings).squeeze() | |
| # Compute the loss and its gradients | |
| loss = self.loss_fn(outputs, targets.long()) | |
| loss.backward() | |
| # Adjust learning weights | |
| self.optimizer.step() | |
| # print statistics | |
| running_loss += loss.item() | |
| # progress bar | |
| pbar.set_description('[TRAINING]') | |
| pbar.set_postfix(loss=running_loss/(idx+1)) | |
| pbar.refresh() | |
| return running_loss / len(self.train_loader) | |
| def _validate_one_epoch(self, data_loader, model=None): | |
| data_targets = [] | |
| data_preds = [] | |
| running_loss = 0.0 | |
| model = self.model if model is None else model | |
| with torch.no_grad(): | |
| for idx, data in enumerate(pbar := tqdm(data_loader)): | |
| # Every data instance is an input + label pair | |
| smiles, targets = data | |
| targets = targets.clone().detach().to(self.device) | |
| # Make predictions for this batch | |
| embeddings = model.extract_embeddings(smiles).to(self.device) | |
| predictions = model.net(embeddings).squeeze() | |
| # Compute the loss | |
| loss = self.loss_fn(predictions, targets.long()) | |
| data_targets.append(targets.view(-1)) | |
| data_preds.append(predictions) | |
| # print statistics | |
| running_loss += loss.item() | |
| # progress bar | |
| pbar.set_description('[EVALUATION]') | |
| pbar.set_postfix(loss=running_loss/(idx+1)) | |
| pbar.refresh() | |
| # Put together predictions and labels from batches | |
| preds = torch.cat(data_preds, dim=0).cpu().numpy() | |
| tgts = torch.cat(data_targets, dim=0).cpu().numpy() | |
| # Compute metrics | |
| preds_cpu = F.softmax(torch.tensor(preds), dim=1).cpu().numpy()[:, 1] | |
| # accuracy | |
| y_pred = np.where(preds_cpu >= 0.5, 1, 0) | |
| accuracy = accuracy_score(tgts, y_pred) | |
| # sensitivity | |
| sn = sensitivity(tgts, y_pred) | |
| # specificity | |
| sp = specificity(tgts, y_pred) | |
| # roc-auc | |
| fpr, tpr, _ = roc_curve(tgts, preds_cpu) | |
| roc_auc = auc(fpr, tpr) | |
| # prc-auc | |
| precision, recall, _ = precision_recall_curve(tgts, preds_cpu) | |
| prc_auc = auc(recall, precision) | |
| # Rearange metrics | |
| metrics = { | |
| 'acc': accuracy, | |
| 'roc-auc': roc_auc, | |
| 'prc-auc': prc_auc, | |
| 'sensitivity': sn, | |
| 'specificity': sp, | |
| } | |
| return preds, running_loss / len(data_loader), metrics | |
| class TrainerClassifierMultitask(Trainer): | |
| def __init__(self, raw_data, dataset_name, target, batch_size, hparams, | |
| target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'): | |
| super().__init__(raw_data, dataset_name, target, batch_size, hparams, | |
| target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device) | |
| def _prepare_data(self): | |
| # normalize dataset | |
| self.df_train['canon_smiles'] = self.df_train['smiles'].apply(normalize_smiles) | |
| self.df_valid['canon_smiles'] = self.df_valid['smiles'].apply(normalize_smiles) | |
| self.df_test['canon_smiles'] = self.df_test['smiles'].apply(normalize_smiles) | |
| self.df_train = self.df_train.dropna(subset=['canon_smiles']) | |
| self.df_valid = self.df_valid.dropna(subset=['canon_smiles']) | |
| self.df_test = self.df_test.dropna(subset=['canon_smiles']) | |
| # create dataloader | |
| self.train_loader = DataLoader( | |
| CustomDatasetMultitask(self.df_train, self.target), | |
| batch_size=self.batch_size, | |
| shuffle=True, | |
| pin_memory=True | |
| ) | |
| self.valid_loader = DataLoader( | |
| CustomDatasetMultitask(self.df_valid, self.target), | |
| batch_size=self.batch_size, | |
| shuffle=False, | |
| pin_memory=True | |
| ) | |
| self.test_loader = DataLoader( | |
| CustomDatasetMultitask(self.df_test, self.target), | |
| batch_size=self.batch_size, | |
| shuffle=False, | |
| pin_memory=True | |
| ) | |
| def _train_one_epoch(self): | |
| running_loss = 0.0 | |
| for idx, data in enumerate(pbar := tqdm(self.train_loader)): | |
| # Every data instance is an input + label pair + mask | |
| smiles, targets, target_masks = data | |
| targets = targets.clone().detach().to(self.device) | |
| # zero the parameter gradients (otherwise they are accumulated) | |
| self.optimizer.zero_grad() | |
| # Make predictions for this batch | |
| embeddings = self.model.extract_embeddings(smiles).to(self.device) | |
| outputs = self.model.net(embeddings, multitask=True).squeeze() | |
| outputs = outputs * target_masks.to(self.device) | |
| # Compute the loss and its gradients | |
| loss = self.loss_fn(outputs, targets) | |
| loss.backward() | |
| # Adjust learning weights | |
| self.optimizer.step() | |
| # print statistics | |
| running_loss += loss.item() | |
| # progress bar | |
| pbar.set_description('[TRAINING]') | |
| pbar.set_postfix(loss=running_loss/(idx+1)) | |
| pbar.refresh() | |
| return running_loss / len(self.train_loader) | |
| def _validate_one_epoch(self, data_loader, model=None): | |
| data_targets = [] | |
| data_preds = [] | |
| data_masks = [] | |
| running_loss = 0.0 | |
| model = self.model if model is None else model | |
| with torch.no_grad(): | |
| for idx, data in enumerate(pbar := tqdm(data_loader)): | |
| # Every data instance is an input + label pair + mask | |
| smiles, targets, target_masks = data | |
| targets = targets.clone().detach().to(self.device) | |
| # Make predictions for this batch | |
| embeddings = model.extract_embeddings(smiles).to(self.device) | |
| predictions = model.net(embeddings, multitask=True).squeeze() | |
| predictions = predictions * target_masks.to(self.device) | |
| # Compute the loss | |
| loss = self.loss_fn(predictions, targets) | |
| data_targets.append(targets) | |
| data_preds.append(predictions) | |
| data_masks.append(target_masks) | |
| # print statistics | |
| running_loss += loss.item() | |
| # progress bar | |
| pbar.set_description('[EVALUATION]') | |
| pbar.set_postfix(loss=running_loss/(idx+1)) | |
| pbar.refresh() | |
| # Put together predictions and labels from batches | |
| preds = torch.cat(data_preds, dim=0) | |
| tgts = torch.cat(data_targets, dim=0) | |
| mask = torch.cat(data_masks, dim=0) | |
| mask = mask > 0 | |
| # Compute metrics | |
| roc_aucs = [] | |
| prc_aucs = [] | |
| sns = [] | |
| sps = [] | |
| num_tasks = len(self.target) | |
| for idx in range(num_tasks): | |
| actuals_task = torch.masked_select(tgts[:, idx], mask[:, idx].to(self.device)) | |
| preds_task = torch.masked_select(preds[:, idx], mask[:, idx].to(self.device)) | |
| # accuracy | |
| y_pred = np.where(preds_task.cpu().detach() >= 0.5, 1, 0) | |
| accuracy = accuracy_score(actuals_task.cpu().numpy(), y_pred) | |
| # sensitivity | |
| sn = sensitivity(actuals_task.cpu().numpy(), y_pred) | |
| # specificity | |
| sp = specificity(actuals_task.cpu().numpy(), y_pred) | |
| # roc-auc | |
| roc_auc = roc_auc_score(actuals_task.cpu().numpy(), preds_task.cpu().numpy()) | |
| # prc-auc | |
| precision, recall, thresholds = precision_recall_curve(actuals_task.cpu().numpy(), preds_task.cpu().numpy()) | |
| prc_auc = auc(recall, precision) | |
| # append | |
| sns.append(sn) | |
| sps.append(sp) | |
| roc_aucs.append(roc_auc) | |
| prc_aucs.append(prc_auc) | |
| average_sn = torch.mean(torch.tensor(sns)) | |
| average_sp = torch.mean(torch.tensor(sps)) | |
| average_roc_auc = torch.mean(torch.tensor(roc_aucs)) | |
| average_prc_auc = torch.mean(torch.tensor(prc_aucs)) | |
| # Rearange metrics | |
| metrics = { | |
| 'acc': accuracy, | |
| 'roc-auc': average_roc_auc.item(), | |
| 'prc-auc': average_prc_auc.item(), | |
| 'sensitivity': average_sn.item(), | |
| 'specificity': average_sp.item(), | |
| } | |
| return preds.cpu().numpy(), running_loss / len(data_loader), metrics |