import inspect
import os
import sys

from loguru import logger

import torch


def get_caller_name(depth=0):
    """
    Args:
        depth (int): depth of the caller's context relative to this function;
            use 0 for the immediate caller. Default value: 0.

    Returns:
        str: module name of the caller.
    """
    # f_back of the current frame is the immediate caller; each extra level
    # of depth walks one frame further up the call stack.
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back

    return frame.f_globals["__name__"]
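
# Illustrative sketch (the module layout below is hypothetical, not part of
# this file):
#
#     # inside foo/bar.py
#     def who_called_me():
#         # depth=1 skips who_called_me's own frame and reports the module
#         # of whoever called it; depth=0 would return "foo.bar" itself.
#         return get_caller_name(depth=1)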


class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected modules.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        full_name = get_caller_name(depth=1)
        # top-level package of the calling module
        module_name = full_name.split(".", maxsplit=1)[0]
        if module_name in self.caller_names:
            # log line by line; depth=2 attributes each record to the
            # original call site rather than to this write() method
            for line in buf.rstrip().splitlines():
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        pass


def redirect_sys_output(log_level="INFO"):
    # route both stdout and stderr through the loguru-backed stream
    redirect_logger = StreamToLoguru(log_level)
    sys.stderr = redirect_logger
    sys.stdout = redirect_logger
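
# Usage sketch (illustrative; assumes pycocotools is installed and the
# annotation path exists):
#
#     redirect_sys_output("INFO")
#     from pycocotools.coco import COCO
#     COCO("instances_val2017.json")  # its "loading annotations..." prints now flow through loguru
#     print("hello")                  # this module is not whitelisted -> real stdout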


def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """setup logger for training and testing.

    Args:
        save_dir(str): location to save the log file.
        distributed_rank(int): device rank in a multi-GPU environment.
            Only rank 0 gets console and file handlers.
        filename (str): log save name.
        mode(str): log file write mode, "a" (append) or "o" (override).
            Default value: "a".

    Returns:
        None. The global loguru logger is configured in place.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    )

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)

    # only the rank-0 process writes to the console and the log file
    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr of whitelisted third-party modules to loguru
    redirect_sys_output("INFO")
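
# Usage sketch (directory and file names are illustrative):
#
#     setup_logger("./YOLOX_outputs/exp", distributed_rank=0, mode="o")
#     logger.info("console and file logging now active on rank 0")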


class WandbLogger(object):
    """
    Log training runs, datasets, models, and predictions to Weights & Biases.
    This logger sends information to W&B at wandb.ai.
    By default, this information includes hyperparameters,
    system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    For more information, please refer to:
    https://docs.wandb.ai/guides/track
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """
        Args:
            project (str): wandb project name.
            name (str): wandb run name.
            id (str): wandb run id.
            entity (str): wandb entity name.
            save_dir (str): save directory.
            config (dict): config dict.
            **kwargs: other kwargs, passed through to wandb.init.
        """
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "wandb is not installed. "
                "Please install wandb using: pip install wandb"
            )

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        self._wandb_init.update(**kwargs)

        # touching the property initializes the underlying wandb run
        _ = self.run

        if self.config:
            self.run.config.update(self.config)
        # the glob "val/*" plots every validation metric against the epoch
        # axis; a bare "val/" would only match a metric literally named "val/"
        self.run.define_metric("epoch")
        self.run.define_metric("val/*", step_metric="epoch")

    @property
    def run(self):
        if self._run is None:
            if self.wandb.run is not None:
                logger.info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse "
                    "this run. If this is not desired, call `wandb.finish()` "
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, step=None):
        """
        Args:
            metrics (dict): metrics dict.
            step (int): step number.
        """
        # wandb cannot serialize tensors; unwrap them to Python scalars first
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                metrics[k] = v.item()

        if step is not None:
            self.run.log(metrics, step=step)
        else:
            self.run.log(metrics)
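
    # Usage sketch (metric names and variables are illustrative):
    #
    #     wandb_logger.log_metrics({"train/loss": loss_tensor, "lr": 0.01}, step=iter_count)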
    def save_checkpoint(self, save_dir, model_name, is_best):
        """
        Args:
            save_dir (str): save directory.
            model_name (str): model name.
            is_best (bool): whether the model is the best model.
        """
        # the checkpoint file is expected to already exist at this path
        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
        artifact = self.wandb.Artifact(
            name=f"model-{self.run.id}",
            type="model"
        )
        artifact.add_file(filename, name="model_ckpt.pth")

        aliases = ["latest"]

        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)
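
    # Retrieval sketch (artifact name and alias are illustrative): a logged
    # checkpoint can later be fetched from another run via the aliases:
    #
    #     art = run.use_artifact(f"model-{run_id}:best")
    #     art.download()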
    def finish(self):
        self.run.finish()
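

# Overall usage sketch (project name, paths, and the surrounding training loop
# are illustrative; assumes the trainer already wrote
# ./YOLOX_outputs/exp1/yolox_s_ckpt.pth):
#
#     wandb_logger = WandbLogger(project="yolox", name="exp1", config={"lr": 0.01})
#     wandb_logger.log_metrics({"val/COCOAP50": ap50, "epoch": epoch})
#     wandb_logger.save_checkpoint("./YOLOX_outputs/exp1", "yolox_s", is_best=True)
#     wandb_logger.finish()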