import math
import os

import torch
# Explicit import: the isinstance check in train() should not depend on what
# the wildcard module imports below happen to re-export.
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.modules.optimizers import *
from src.modules.embeddings import *
from src.modules.schedulers import *
from src.modules.tokenizers import *
from src.modules.metrics import *
from src.modules.losses import *
from src.utils.misc import *
from src.utils.logger import Logger
from src.utils.mapper import configmapper
from src.utils.configuration import Config


@configmapper.map("trainers", "base")
class BaseTrainer:
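    """Config-driven training loop.

    Metrics, optimizer, scheduler, and loss are all resolved through
    configmapper, so the trainer itself stays model- and task-agnostic.
    """
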
    def __init__(self, config):
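        """Store the config and resolve metric callables and their kwargs."""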
        self._config = config
        self.metrics = {
            configmapper.get_object("metrics", metric["type"]): metric["params"]
            for metric in self._config.main_config.metrics
        }
        self.train_config = self._config.train
        self.val_config = self._config.val
        self.log_label = self.train_config.log.log_label
        # Initialise the flag unconditionally; it was previously only assigned
        # when log_and_val_interval was set, leaving the attribute undefined
        # (and train() crashing) otherwise.
        self.val_log_together = False
        if self.train_config.log_and_val_interval is not None:
            self.val_log_together = True
        print("Logging with label:", self.log_label)

    def train(self, model, train_dataset, val_dataset=None, logger=None):
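        """Train model on train_dataset, validating and logging periodically.

        Optimizer, scheduler, criterion, and DataLoader settings all come from
        the train section of the config; best/final checkpoints are written
        when save_on is configured.
        """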
        device = torch.device(self._config.main_config.device.name)
        model.to(device)

        # Resolve optimizer, scheduler, and criterion through configmapper.
        optim_params = self.train_config.optimizer.params
        if optim_params:
            optimizer = configmapper.get_object(
                "optimizers", self.train_config.optimizer.type
            )(model.parameters(), **optim_params.as_dict())
        else:
            optimizer = configmapper.get_object(
                "optimizers", self.train_config.optimizer.type
            )(model.parameters())

        if self.train_config.scheduler is not None:
            scheduler_params = self.train_config.scheduler.params
            if scheduler_params:
                scheduler = configmapper.get_object(
                    "schedulers", self.train_config.scheduler.type
                )(optimizer, **scheduler_params.as_dict())
            else:
                scheduler = configmapper.get_object(
                    "schedulers", self.train_config.scheduler.type
                )(optimizer)

        criterion_params = self.train_config.criterion.params
        if criterion_params:
            criterion = configmapper.get_object(
                "losses", self.train_config.criterion.type
            )(**criterion_params.as_dict())
        else:
            criterion = configmapper.get_object(
                "losses", self.train_config.criterion.type
            )()
        # Use the dataset's custom collate function when it provides one.
        if hasattr(train_dataset, "custom_collate_fn"):
            train_loader = DataLoader(
                dataset=train_dataset,
                collate_fn=train_dataset.custom_collate_fn,
                **self.train_config.loader_params.as_dict(),
            )
        else:
            train_loader = DataLoader(
                dataset=train_dataset, **self.train_config.loader_params.as_dict()
            )
        max_epochs = self.train_config.max_epochs
        batch_size = self.train_config.loader_params.batch_size

        if self.val_log_together:
            val_interval = self.train_config.log_and_val_interval
            log_interval = val_interval
        else:
            val_interval = self.train_config.val_interval
            log_interval = self.train_config.log.log_interval

        if logger is None:
            train_logger = Logger(**self.train_config.log.logger_params.as_dict())
        else:
            train_logger = logger

        train_log_values = self.train_config.log.values.as_dict()

        # Track the best checkpoint according to save_on.score.
        best_score = (
            -math.inf if self.train_config.save_on.desired == "max" else math.inf
        )
        save_on_score = self.train_config.save_on.score
        best_step = -1
        best_model = None

        best_hparam_list = None
        best_hparam_name_list = None
        best_metrics_list = None
        best_metrics_name_list = None
        global_step = 0
        for epoch in range(1, max_epochs + 1):
            print(f"Epoch: {epoch}/{max_epochs}, Global Step: {global_step}")
            train_loss = 0
            val_loss = 0

            if self.train_config.label_type == "float":
                all_labels = torch.FloatTensor().to(device)
            else:
                all_labels = torch.LongTensor().to(device)
            all_outputs = torch.Tensor().to(device)

            train_scores = None
            val_scores = None

            pbar = tqdm(total=math.ceil(len(train_dataset) / batch_size))
            pbar.set_description(f"Epoch {epoch}")

            val_counter = 0
            for step, batch in enumerate(train_loader):
                model.train()
                optimizer.zero_grad()
                inputs, labels = batch

                if self.train_config.label_type == "float":
                    labels = labels.float()

                for key in inputs:
                    inputs[key] = inputs[key].to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(torch.squeeze(outputs), labels)
                loss.backward()

                # Accumulate labels and predictions for epoch-level metrics.
                all_labels = torch.cat((all_labels, labels), 0)
                if self.train_config.label_type == "float":
                    all_outputs = torch.cat((all_outputs, outputs), 0)
                else:
                    all_outputs = torch.cat(
                        (all_outputs, torch.argmax(outputs, dim=1)), 0
                    )

                train_loss += loss.item()
                optimizer.step()

                if self.train_config.scheduler is not None:
                    # ReduceLROnPlateau needs the monitored value; other
                    # schedulers step unconditionally.
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(train_loss / (step + 1))
                    else:
                        scheduler.step()

                pbar.set_postfix_str(f"Train Loss: {train_loss / (step + 1)}")
                pbar.update(1)

                global_step += 1
                # Validate (and possibly checkpoint) every val_interval steps.
                if val_dataset is not None and (global_step - 1) % val_interval == 0:
                    val_scores = self.val(
                        model,
                        val_dataset,
                        criterion,
                        device,
                        global_step,
                        train_logger,
                        train_log_values,
                    )

                    if self.train_config.save_on is not None:
                        # Average over the steps completed so far this epoch;
                        # the original divided by global_step, which is only
                        # correct during the first epoch.
                        train_scores = self.get_scores(
                            train_loss,
                            step + 1,
                            self.train_config.criterion.type,
                            all_outputs,
                            all_labels,
                        )

                        best_score, best_step, save_flag = self.check_best(
                            val_scores,
                            save_on_score,
                            best_score,
                            best_step,
                            global_step,
                        )

                        store_dict = {
                            "model_state_dict": model.state_dict(),
                            "best_step": best_step,
                            "best_score": best_score,
                            "save_on_score": save_on_score,
                        }

                        path = self.train_config.save_on.best_path.format(
                            self.log_label
                        )
                        self.save(store_dict, path, save_flag)

                        if save_flag and train_log_values["hparams"] is not None:
                            (
                                best_hparam_list,
                                best_hparam_name_list,
                                best_metrics_list,
                                best_metrics_name_list,
                            ) = self.update_hparams(
                                train_scores, val_scores, desc="best_val"
                            )
                if (global_step - 1) % log_interval == 0:
                    train_loss_name = self.train_config.criterion.type
                    metric_list = [
                        metric(
                            all_labels.cpu(),
                            all_outputs.detach().cpu(),
                            **self.metrics[metric],
                        )
                        for metric in self.metrics
                    ]
                    metric_name_list = [
                        metric["type"] for metric in self._config.main_config.metrics
                    ]

                    train_scores = self.log(
                        train_loss / (step + 1),
                        train_loss_name,
                        metric_list,
                        metric_name_list,
                        train_logger,
                        train_log_values,
                        global_step,
                        append_text=self.train_config.append_text,
                    )
            pbar.close()

            os.makedirs(self.train_config.checkpoint.checkpoint_dir, exist_ok=True)
            if self.train_config.save_after_epoch:
                store_dict = {
                    "model_state_dict": model.state_dict(),
                }
                path = (
                    f"{self.train_config.checkpoint.checkpoint_dir}"
                    f"_{self.train_config.log.log_label}_{epoch}.pth"
                )
                self.save(store_dict, path, save_flag=1)
            if epoch == max_epochs:
                # Final-epoch validation; guarded so that training without a
                # val_dataset (previously a crash here) still finishes cleanly.
                if val_dataset is not None:
                    val_scores = self.val(
                        model,
                        val_dataset,
                        criterion,
                        device,
                        global_step,
                        train_logger,
                        train_log_values,
                    )

                train_loss_name = self.train_config.criterion.type
                metric_list = [
                    metric(
                        all_labels.cpu(),
                        all_outputs.detach().cpu(),
                        **self.metrics[metric],
                    )
                    for metric in self.metrics
                ]
                metric_name_list = [
                    metric["type"] for metric in self._config.main_config.metrics
                ]

                train_scores = self.log(
                    train_loss / len(train_loader),
                    train_loss_name,
                    metric_list,
                    metric_name_list,
                    train_logger,
                    train_log_values,
                    global_step,
                    append_text=self.train_config.append_text,
                )
                # Best/final checkpoint bookkeeping relies on val_scores, so it
                # is skipped when no validation set was provided.
                if self.train_config.save_on is not None and val_dataset is not None:
                    train_scores = self.get_scores(
                        train_loss,
                        len(train_loader),
                        self.train_config.criterion.type,
                        all_outputs,
                        all_labels,
                    )

                    best_score, best_step, save_flag = self.check_best(
                        val_scores, save_on_score, best_score, best_step, global_step
                    )

                    store_dict = {
                        "model_state_dict": model.state_dict(),
                        "best_step": best_step,
                        "best_score": best_score,
                        "save_on_score": save_on_score,
                    }

                    path = self.train_config.save_on.best_path.format(self.log_label)
                    self.save(store_dict, path, save_flag)

                    if save_flag and train_log_values["hparams"] is not None:
                        (
                            best_hparam_list,
                            best_hparam_name_list,
                            best_metrics_list,
                            best_metrics_name_list,
                        ) = self.update_hparams(
                            train_scores, val_scores, desc="best_val"
                        )
                    train_scores = self.get_scores(
                        train_loss,
                        len(train_loader),
                        self.train_config.criterion.type,
                        all_outputs,
                        all_labels,
                    )

                    store_dict = {
                        "model_state_dict": model.state_dict(),
                        "final_step": global_step,
                        "final_score": train_scores[save_on_score],
                        "save_on_score": save_on_score,
                    }

                    path = self.train_config.save_on.final_path.format(self.log_label)
                    self.save(store_dict, path, save_flag=1)

                    if train_log_values["hparams"] is not None:
                        (
                            final_hparam_list,
                            final_hparam_name_list,
                            final_metrics_list,
                            final_metrics_name_list,
                        ) = self.update_hparams(train_scores, val_scores, desc="final")
                        # Assumes at least one best checkpoint was recorded, so
                        # the best_* lists are populated.
                        train_logger.save_hyperparams(
                            best_hparam_list,
                            best_hparam_name_list,
                            [int(self.log_label)] + best_metrics_list + final_metrics_list,
                            ["hparams/log_label"]
                            + best_metrics_name_list
                            + final_metrics_name_list,
                        )

    def get_scores(self, loss, divisor, loss_name, all_outputs, all_labels):
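        """Return {loss_name: loss / divisor, metric_name: value, ...}."""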
        avg_loss = loss / divisor

        metric_list = [
            metric(
                all_labels.cpu(),
                all_outputs.detach().cpu(),
                **self.metrics[metric],
            )
            for metric in self.metrics
        ]
        metric_name_list = [
            metric["type"] for metric in self._config.main_config.metrics
        ]

        return dict(zip([loss_name] + metric_name_list, [avg_loss] + metric_list))

    def check_best(
        self, val_scores, save_on_score, best_score, best_step, global_step
    ):
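        """Compare the monitored score against the running best.

        Returns the (possibly updated) best_score and best_step, plus a flag
        indicating whether a new best checkpoint should be saved.
        """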
        # Only advance best_step on an actual improvement; the original
        # overwrote it with the current global_step on every call.
        save_flag = 0
        if self.train_config.save_on.desired == "min":
            if val_scores[save_on_score] < best_score:
                save_flag = 1
                best_score = val_scores[save_on_score]
                best_step = global_step
        else:
            if val_scores[save_on_score] > best_score:
                save_flag = 1
                best_score = val_scores[save_on_score]
                best_step = global_step
        return best_score, best_step, save_flag

    def update_hparams(self, train_scores, val_scores, desc):
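        """Gather hparam values named in the config and prefix score keys with
        hparams/{desc}_train_ / hparams/{desc}_val_ for the logger."""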
        hparam_list = []
        hparam_name_list = []
        for hparam in self.train_config.log.values.hparams:
            hparam_list.append(get_item_in_config(self._config, hparam["path"]))
            if isinstance(hparam_list[-1], Config):
                hparam_list[-1] = hparam_list[-1].as_dict()
            hparam_name_list.append(hparam["name"])

        val_keys, val_values = map(list, zip(*val_scores.items()))
        train_keys, train_values = map(list, zip(*train_scores.items()))
        for i, key in enumerate(val_keys):
            val_keys[i] = f"hparams/{desc}_val_{key}"
        for i, key in enumerate(train_keys):
            train_keys[i] = f"hparams/{desc}_train_{key}"

        return (
            hparam_list,
            hparam_name_list,
            train_values + val_values,
            train_keys + val_keys,
        )

    def save(self, store_dict, path, save_flag=0):
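        """Write store_dict to path (creating parent dirs) when save_flag is set."""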
        if save_flag:
            dirs = os.path.dirname(path)
            if dirs and not os.path.exists(dirs):
                os.makedirs(dirs)
            torch.save(store_dict, path)

    def log(
        self,
        loss,
        loss_name,
        metric_list,
        metric_name_list,
        logger,
        log_values,
        global_step,
        append_text,
    ):
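        """Send loss and metric values to the logger; return them as a dict."""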
        return_dic = dict(zip([loss_name] + metric_name_list, [loss] + metric_list))

        loss_name = f"{append_text}_{self.log_label}_{loss_name}"
        if log_values["loss"]:
            logger.save_params(
                [loss],
                [loss_name],
                combine=True,
                combine_name="losses",
                global_step=global_step,
            )

        for i in range(len(metric_name_list)):
            metric_name_list[i] = f"{append_text}_{self.log_label}_{metric_name_list[i]}"
        if log_values["metrics"]:
            logger.save_params(
                metric_list,
                metric_name_list,
                combine=True,
                combine_name="metrics",
                global_step=global_step,
            )

        return return_dic

    def val(
        self,
        model,
        dataset,
        criterion,
        device,
        global_step,
        train_logger=None,
        train_log_values=None,
        log=True,
    ):
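        """Evaluate model on dataset; return a dict of loss and metric scores."""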
        append_text = self.val_config.append_text
        if train_logger is not None:
            val_logger = train_logger
        else:
            val_logger = Logger(**self.val_config.log.logger_params.as_dict())

        if train_log_values is not None:
            val_log_values = train_log_values
        else:
            val_log_values = self.val_config.log.values.as_dict()

        if hasattr(dataset, "custom_collate_fn"):
            val_loader = DataLoader(
                dataset=dataset,
                collate_fn=dataset.custom_collate_fn,
                **self.val_config.loader_params.as_dict(),
            )
        else:
            val_loader = DataLoader(
                dataset=dataset, **self.val_config.loader_params.as_dict()
            )

        all_outputs = torch.Tensor().to(device)
        if self.train_config.label_type == "float":
            all_labels = torch.FloatTensor().to(device)
        else:
            all_labels = torch.LongTensor().to(device)

        batch_size = self.val_config.loader_params.batch_size

        with torch.no_grad():
            model.eval()
            val_loss = 0
            for j, batch in enumerate(val_loader):
                inputs, labels = batch

                if self.train_config.label_type == "float":
                    labels = labels.float()

                for key in inputs:
                    inputs[key] = inputs[key].to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(torch.squeeze(outputs), labels)
                val_loss += loss.item()

                all_labels = torch.cat((all_labels, labels), 0)
                if self.train_config.label_type == "float":
                    all_outputs = torch.cat((all_outputs, outputs), 0)
                else:
                    all_outputs = torch.cat(
                        (all_outputs, torch.argmax(outputs, dim=1)), 0
                    )

            val_loss = val_loss / len(val_loader)
            val_loss_name = self.train_config.criterion.type

            metric_list = [
                metric(
                    all_labels.cpu(),
                    all_outputs.detach().cpu(),
                    **self.metrics[metric],
                )
                for metric in self.metrics
            ]
            metric_name_list = [
                metric["type"] for metric in self._config.main_config.metrics
            ]
            return_dic = dict(
                zip([val_loss_name] + metric_name_list, [val_loss] + metric_list)
            )
            if log:
                val_scores = self.log(
                    val_loss,
                    val_loss_name,
                    metric_list,
                    metric_name_list,
                    val_logger,
                    val_log_values,
                    global_step,
                    append_text,
                )
                return val_scores
            return return_dic