# src/trainers/base_trainer.py
import math
import os
import torch
from src.modules.optimizers import *
from src.modules.embeddings import *
from src.modules.schedulers import *
from src.modules.tokenizers import *
from src.modules.metrics import *
from src.modules.losses import *
from src.utils.misc import *
from src.utils.logger import Logger
from src.utils.mapper import configmapper
from src.utils.configuration import Config
from torch.utils.data import DataLoader
from tqdm import tqdm
@configmapper.map("trainers", "base")
class BaseTrainer:
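    """Config-driven training loop.

    Resolves the optimizer, scheduler, criterion, and metrics from the config
    via ``configmapper``, then runs training with periodic validation, logging,
    and checkpointing. Registered under ``("trainers", "base")``.
    """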
def __init__(self, config):
self._config = config
self.metrics = {
configmapper.get_object("metrics", metric["type"]): metric["params"]
for metric in self._config.main_config.metrics
}
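        # self.metrics maps each metric callable (resolved from the registry)
        # to its kwargs, e.g. {f1_score: {"average": "macro"}} (illustrative).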
self.train_config = self._config.train
self.val_config = self._config.val
self.log_label = self.train_config.log.log_label
        if self.train_config.log_and_val_interval is not None:
            self.val_log_together = True
        else:
            # must be set explicitly, otherwise train() hits an AttributeError
            self.val_log_together = False
print("Logging with label: ", self.log_label)
def train(self, model, train_dataset, val_dataset=None, logger=None):
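        """Train ``model`` on ``train_dataset``, validating and checkpointing
        along the way.

        Saves the best model according to ``train_config.save_on`` as well as
        the final model, and logs losses, metrics, and hyperparameters.
        """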
device = torch.device(self._config.main_config.device.name)
model.to(device)
optim_params = self.train_config.optimizer.params
if optim_params:
optimizer = configmapper.get_object(
"optimizers", self.train_config.optimizer.type
)(model.parameters(), **optim_params.as_dict())
else:
optimizer = configmapper.get_object(
"optimizers", self.train_config.optimizer.type
)(model.parameters())
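        # At this point optimizer is, e.g., torch.optim.Adam(model.parameters(), lr=1e-5)
        # if the registry maps "adam" to torch.optim.Adam (illustrative; the actual
        # registry keys live in src/modules/optimizers).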
if self.train_config.scheduler is not None:
scheduler_params = self.train_config.scheduler.params
if scheduler_params:
scheduler = configmapper.get_object(
"schedulers", self.train_config.scheduler.type
)(optimizer, **scheduler_params.as_dict())
else:
scheduler = configmapper.get_object(
"schedulers", self.train_config.scheduler.type
)(optimizer)
criterion_params = self.train_config.criterion.params
if criterion_params:
criterion = configmapper.get_object(
"losses", self.train_config.criterion.type
)(**criterion_params.as_dict())
else:
criterion = configmapper.get_object(
"losses", self.train_config.criterion.type
)()
if "custom_collate_fn" in dir(train_dataset):
train_loader = DataLoader(
dataset=train_dataset,
collate_fn=train_dataset.custom_collate_fn,
**self.train_config.loader_params.as_dict(),
)
else:
train_loader = DataLoader(
dataset=train_dataset, **self.train_config.loader_params.as_dict()
)
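        # A dataset's custom_collate_fn is expected to yield (inputs, labels)
        # per batch, where inputs is a dict of tensors (see the loop below).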
# train_logger = Logger(**self.train_config.log.logger_params.as_dict())
max_epochs = self.train_config.max_epochs
batch_size = self.train_config.loader_params.batch_size
if self.val_log_together:
val_interval = self.train_config.log_and_val_interval
log_interval = val_interval
else:
val_interval = self.train_config.val_interval
log_interval = self.train_config.log.log_interval
if logger is None:
train_logger = Logger(**self.train_config.log.logger_params.as_dict())
else:
train_logger = logger
train_log_values = self.train_config.log.values.as_dict()
best_score = (
-math.inf if self.train_config.save_on.desired == "max" else math.inf
)
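        # best_score starts at -inf when a higher score is better, +inf otherwise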
save_on_score = self.train_config.save_on.score
best_step = -1
best_model = None
best_hparam_list = None
best_hparam_name_list = None
best_metrics_list = None
best_metrics_name_list = None
# print("\nTraining\n")
# print(max_steps)
global_step = 0
for epoch in range(1, max_epochs + 1):
            print(f"Epoch: {epoch}/{max_epochs}, Global Step: {global_step}")
train_loss = 0
val_loss = 0
            if self.train_config.label_type == "float":
                all_labels = torch.FloatTensor().to(device)
            else:
                all_labels = torch.LongTensor().to(device)
all_outputs = torch.Tensor().to(device)
train_scores = None
val_scores = None
pbar = tqdm(total=math.ceil(len(train_dataset) / batch_size))
pbar.set_description("Epoch " + str(epoch))
val_counter = 0
for step, batch in enumerate(train_loader):
model.train()
optimizer.zero_grad()
inputs, labels = batch
                if self.train_config.label_type == "float":  # specific to float labels
                    labels = labels.float()
for key in inputs:
inputs[key] = inputs[key].to(device)
labels = labels.to(device)
outputs = model(inputs)
loss = criterion(torch.squeeze(outputs), labels)
loss.backward()
all_labels = torch.cat((all_labels, labels), 0)
                if self.train_config.label_type == "float":
                    all_outputs = torch.cat((all_outputs, outputs), 0)
                else:
                    all_outputs = torch.cat((all_outputs, torch.argmax(outputs, dim=1)), 0)
train_loss += loss.item()
optimizer.step()
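                # Step the scheduler every batch; plateau schedulers are fed the
                # running average train loss for the current epoch.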
if self.train_config.scheduler is not None:
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(train_loss / (step + 1))
else:
scheduler.step()
# print(train_loss)
# print(step+1)
                pbar.set_postfix_str(f"Train Loss: {train_loss / (step + 1)}")
pbar.update(1)
global_step += 1
# Need to check if we want global_step or local_step
if val_dataset is not None and (global_step - 1) % val_interval == 0:
# print("\nEvaluating\n")
val_scores = self.val(
model,
val_dataset,
criterion,
device,
global_step,
train_logger,
train_log_values,
)
#save_flag = 0
if self.train_config.save_on is not None:
## BEST SCORES UPDATING
train_scores = self.get_scores(
train_loss,
global_step,
self.train_config.criterion.type,
all_outputs,
all_labels,
)
best_score, best_step, save_flag = self.check_best(
val_scores, save_on_score, best_score, global_step
)
store_dict = {
"model_state_dict": model.state_dict(),
"best_step": best_step,
"best_score": best_score,
"save_on_score": save_on_score,
}
path = self.train_config.save_on.best_path.format(
self.log_label
)
self.save(store_dict, path, save_flag)
if save_flag and train_log_values["hparams"] is not None:
(
best_hparam_list,
best_hparam_name_list,
best_metrics_list,
best_metrics_name_list,
) = self.update_hparams(
train_scores, val_scores, desc="best_val"
)
# pbar.close()
if (global_step - 1) % log_interval == 0:
# print("\nLogging\n")
train_loss_name = self.train_config.criterion.type
metric_list = [
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
for metric in self.metrics
]
metric_name_list = [
metric['type'] for metric in self._config.main_config.metrics
]
train_scores = self.log(
train_loss / (step + 1),
train_loss_name,
metric_list,
metric_name_list,
train_logger,
train_log_values,
global_step,
append_text=self.train_config.append_text,
)
pbar.close()
if not os.path.exists(self.train_config.checkpoint.checkpoint_dir):
os.makedirs(self.train_config.checkpoint.checkpoint_dir)
if self.train_config.save_after_epoch:
store_dict = {
"model_state_dict": model.state_dict(),
}
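            # NOTE: the underscore joins checkpoint_dir and the label into a file
            # name *next to* the directory created above, not inside it; use "/"
            # if per-epoch checkpoints should live in the directory itself.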
path = f"{self.train_config.checkpoint.checkpoint_dir}_{str(self.train_config.log.log_label)}_{str(epoch)}.pth"
self.save(store_dict, path, save_flag=1)
if epoch == max_epochs:
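                # Final evaluation/logging on the last epoch. Note this path
                # assumes val_dataset is not None, unlike the mid-epoch check.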
# print("\nEvaluating\n")
val_scores = self.val(
model,
val_dataset,
criterion,
device,
global_step,
train_logger,
train_log_values,
)
# print("\nLogging\n")
train_loss_name = self.train_config.criterion.type
metric_list = [
metric(all_labels.cpu(), all_outputs.detach().cpu(),**self.metrics[metric])
for metric in self.metrics
]
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
train_scores = self.log(
train_loss / len(train_loader),
train_loss_name,
metric_list,
metric_name_list,
train_logger,
train_log_values,
global_step,
append_text=self.train_config.append_text,
)
if self.train_config.save_on is not None:
## BEST SCORES UPDATING
train_scores = self.get_scores(
train_loss,
len(train_loader),
self.train_config.criterion.type,
all_outputs,
all_labels,
)
best_score, best_step, save_flag = self.check_best(
val_scores, save_on_score, best_score, global_step
)
store_dict = {
"model_state_dict": model.state_dict(),
"best_step": best_step,
"best_score": best_score,
"save_on_score": save_on_score,
}
path = self.train_config.save_on.best_path.format(self.log_label)
self.save(store_dict, path, save_flag)
if save_flag and train_log_values["hparams"] is not None:
(
best_hparam_list,
best_hparam_name_list,
best_metrics_list,
best_metrics_name_list,
) = self.update_hparams(train_scores, val_scores, desc="best_val")
## FINAL SCORES UPDATING + STORING
train_scores = self.get_scores(
train_loss,
len(train_loader),
self.train_config.criterion.type,
all_outputs,
all_labels,
)
store_dict = {
"model_state_dict": model.state_dict(),
"final_step": global_step,
"final_score": train_scores[save_on_score],
"save_on_score": save_on_score,
}
path = self.train_config.save_on.final_path.format(self.log_label)
self.save(store_dict, path, save_flag=1)
if train_log_values["hparams"] is not None:
(
final_hparam_list,
final_hparam_name_list,
final_metrics_list,
final_metrics_name_list,
) = self.update_hparams(train_scores, val_scores, desc="final")
train_logger.save_hyperparams(
best_hparam_list,
best_hparam_name_list,
[int(self.log_label),] + best_metrics_list + final_metrics_list,
["hparams/log_label",]
+ best_metrics_name_list
+ final_metrics_name_list,
)
    ## Need to check if we want the same logger or different loggers for train and eval.

    ## Evaluation / scoring helpers
def get_scores(self, loss, divisor, loss_name, all_outputs, all_labels):
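        """Compute ``{loss_name: avg_loss, metric_name: value, ...}`` from the
        accumulated outputs and labels, averaging the loss over ``divisor``."""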
avg_loss = loss / divisor
metric_list = [
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
for metric in self.metrics
]
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
        return dict(zip([loss_name] + metric_name_list, [avg_loss] + metric_list))
def check_best(self, val_scores, save_on_score, best_score, global_step):
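        """Compare ``val_scores[save_on_score]`` against ``best_score`` and
        return the (possibly updated) best score, best step, and a save flag."""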
save_flag = 0
best_step = global_step
if self.train_config.save_on.desired == "min":
if val_scores[save_on_score] < best_score:
save_flag = 1
best_score = val_scores[save_on_score]
best_step = global_step
else:
if val_scores[save_on_score] > best_score:
save_flag = 1
best_score = val_scores[save_on_score]
best_step = global_step
return best_score, best_step, save_flag
def update_hparams(self, train_scores, val_scores, desc):
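        """Collect hyperparameter values/names from the config and prefix the
        train/val score keys with ``hparams/{desc}_...`` for hparam logging."""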
hparam_list = []
hparam_name_list = []
for hparam in self.train_config.log.values.hparams:
hparam_list.append(get_item_in_config(self._config, hparam["path"]))
if isinstance(hparam_list[-1], Config):
hparam_list[-1] = hparam_list[-1].as_dict()
hparam_name_list.append(hparam["name"])
val_keys, val_values = zip(*val_scores.items())
train_keys, train_values = zip(*train_scores.items())
val_keys = list(val_keys)
train_keys = list(train_keys)
val_values = list(val_values)
train_values = list(train_values)
for i, key in enumerate(val_keys):
val_keys[i] = f"hparams/{desc}_val_" + val_keys[i]
for i, key in enumerate(train_keys):
train_keys[i] = f"hparams/{desc}_train_" + train_keys[i]
# train_logger.save_hyperparams(hparam_list, hparam_name_list,train_values+val_values,train_keys+val_keys, )
return (
hparam_list,
hparam_name_list,
train_values + val_values,
train_keys + val_keys,
)
def save(self, store_dict, path, save_flag=0):
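        """Persist ``store_dict`` to ``path`` (creating parent directories)
        when ``save_flag`` is truthy; otherwise do nothing."""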
if save_flag:
dirs = "/".join(path.split("/")[:-1])
if not os.path.exists(dirs):
os.makedirs(dirs)
torch.save(store_dict, path)
def log(
self,
loss,
loss_name,
metric_list,
metric_name_list,
logger,
log_values,
global_step,
append_text,
):
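        """Log the loss and metrics under ``{append_text}_{log_label}_``-prefixed
        names and return a dict of the raw (unprefixed) values."""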
        return_dic = dict(zip([loss_name] + metric_name_list, [loss] + metric_list))
loss_name = f"{append_text}_{self.log_label}_{loss_name}"
if log_values["loss"]:
logger.save_params(
[loss],
[loss_name],
combine=True,
combine_name="losses",
global_step=global_step,
)
        for i in range(len(metric_name_list)):
            metric_name_list[i] = f"{append_text}_{self.log_label}_{metric_name_list[i]}"
if log_values["metrics"]:
logger.save_params(
metric_list,
metric_name_list,
combine=True,
combine_name="metrics",
global_step=global_step,
)
# print(hparams_list)
# print(hparam_name_list)
# for k,v in dict(zip([loss_name],[loss])).items():
# print(f"{k}:{v}")
# for k,v in dict(zip(metric_name_list,metric_list)).items():
# print(f"{k}:{v}")
return return_dic
def val(
self,
model,
dataset,
criterion,
device,
global_step,
train_logger=None,
train_log_values=None,
log=True,
):
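        """Evaluate ``model`` on ``dataset`` and return a dict of loss and
        metric scores; reuses the train logger and log config when provided,
        and logs the scores unless ``log`` is False."""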
append_text = self.val_config.append_text
if train_logger is not None:
val_logger = train_logger
else:
val_logger = Logger(**self.val_config.log.logger_params.as_dict())
if train_log_values is not None:
val_log_values = train_log_values
else:
val_log_values = self.val_config.log.values.as_dict()
if "custom_collate_fn" in dir(dataset):
val_loader = DataLoader(
dataset=dataset,
collate_fn=dataset.custom_collate_fn,
**self.val_config.loader_params.as_dict(),
)
else:
val_loader = DataLoader(
dataset=dataset, **self.val_config.loader_params.as_dict()
)
all_outputs = torch.Tensor().to(device)
        if self.train_config.label_type == "float":
            all_labels = torch.FloatTensor().to(device)
        else:
            all_labels = torch.LongTensor().to(device)
batch_size = self.val_config.loader_params.batch_size
with torch.no_grad():
model.eval()
val_loss = 0
for j, batch in enumerate(val_loader):
inputs, labels = batch
                if self.train_config.label_type == "float":
                    labels = labels.float()
for key in inputs:
inputs[key] = inputs[key].to(device)
labels = labels.to(device)
outputs = model(inputs)
loss = criterion(torch.squeeze(outputs), labels)
val_loss += loss.item()
all_labels = torch.cat((all_labels, labels), 0)
                if self.train_config.label_type == "float":
                    all_outputs = torch.cat((all_outputs, outputs), 0)
                else:
                    all_outputs = torch.cat((all_outputs, torch.argmax(outputs, dim=1)), 0)
val_loss = val_loss / len(val_loader)
val_loss_name = self.train_config.criterion.type
# print(all_outputs, all_labels)
metric_list = [
metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
for metric in self.metrics
]
metric_name_list = [metric['type'] for metric in self._config.main_config.metrics]
return_dic = dict(
zip([val_loss_name,] + metric_name_list, [val_loss,] + metric_list,)
)
if log:
val_scores = self.log(
val_loss,
val_loss_name,
metric_list,
metric_name_list,
val_logger,
val_log_values,
global_step,
append_text,
)
return val_scores
return return_dic
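

# Minimal usage sketch (hypothetical entry point; the real one lives elsewhere
# in this repo, and the Config loading API is not shown here):
#
#   from src.utils.configuration import Config
#   from src.utils.mapper import configmapper
#
#   config = Config(...)  # load the trainer config
#   trainer = configmapper.get_object("trainers", "base")(config)
#   trainer.train(model, train_dataset, val_dataset)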