#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import inspect
import os
import sys

from loguru import logger

import torch


def get_caller_name(depth=0):
    """
    Args:
        depth (int): Depth of caller context, use 0 for caller depth.
            Default value: 0.

    Returns:
        str: module name of the caller
    """
    # the following logic is a little bit faster than inspect.stack() logic
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back
    return frame.f_globals["__name__"]
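
# Example (minimal sketch; helper() is illustrative): called from a function in
# a script run directly, get_caller_name(depth=0) returns "__main__", while
# depth=1 walks one frame further up the stack:
#
#     def helper():
#         return get_caller_name(depth=1)  # module name of helper()'s caller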


class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected module.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        full_name = get_caller_name(depth=1)
        # keep only the top-level package name, e.g. "pycocotools.coco" -> "pycocotools"
        module_name = full_name.rsplit(".", maxsplit=-1)[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # use caller level log
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        pass
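
# Example (minimal sketch; module name is illustrative): once an instance is
# installed as sys.stdout, bare print() calls from the listed modules are
# routed through loguru, while output from any other module falls through to
# the real stdout:
#
#     sys.stdout = StreamToLoguru(level="INFO", caller_names=("pycocotools",))
#     print("hi")  # forwarded to loguru only if printed from pycocotools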
def redirect_sys_output(log_level="INFO"): | |
redirect_logger = StreamToLoguru(log_level) | |
sys.stderr = redirect_logger | |
sys.stdout = redirect_logger | |


def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """setup logger for training and testing.

    Args:
        save_dir(str): location to save log file
        distributed_rank(int): device rank when multi-gpu environment
        filename (string): log save name.
        mode(str): log file write mode, `a` (append) or `o` (override).
            Default value: "a".

    Note:
        The global loguru `logger` is configured in place; nothing is returned.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    )

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)
    # only keep logger in rank0 process
    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")


class WandbLogger(object):
    """
    Log training runs, datasets, models, and predictions to Weights & Biases.
    This logger sends information to W&B at wandb.ai.
    By default, this information includes hyperparameters,
    system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    For more information, please refer to:
    https://docs.wandb.ai/guides/track
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """
        Args:
            project (str): wandb project name.
            name (str): wandb run name.
            id (str): wandb run id.
            entity (str): wandb entity name.
            save_dir (str): save directory.
            config (dict): config dict.
            **kwargs: other kwargs.
        """
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "wandb is not installed. "
                "Please install wandb using `pip install wandb`."
            )

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        self._wandb_init.update(**kwargs)

        _ = self.run

        if self.config:
            self.run.config.update(self.config)
        # log validation metrics against the epoch counter instead of the
        # default wandb step
        self.run.define_metric("epoch")
        self.run.define_metric("val/*", step_metric="epoch")

    @property
    def run(self):
        if self._run is None:
            if self.wandb.run is not None:
                logger.info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse "
                    "this run. If this is not desired, call `wandb.finish()` "
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, step=None):
        """
        Args:
            metrics (dict): metrics dict.
            step (int): step number.
        """
        # convert scalar tensors to plain Python numbers before logging
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                metrics[k] = v.item()

        if step is not None:
            self.run.log(metrics, step=step)
        else:
            self.run.log(metrics)
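
    # Example (minimal sketch; project, metric names, and variables are
    # illustrative):
    #
    #     wb = WandbLogger(project="yolox", name="exp1")
    #     wb.log_metrics({"train/loss": loss}, step=global_step)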

    def save_checkpoint(self, save_dir, model_name, is_best):
        """
        Args:
            save_dir (str): save directory.
            model_name (str): model name.
            is_best (bool): whether the model is the best model.
        """
        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
        artifact = self.wandb.Artifact(
            name=f"model-{self.run.id}",
            type="model"
        )
        artifact.add_file(filename, name="model_ckpt.pth")

        aliases = ["latest"]
        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)
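
    # Note (my reading of wandb artifact semantics): each call logs a new
    # artifact version; the "latest" alias tracks the newest version, while
    # "best" moves only when a checkpoint is logged with is_best=True.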

    def finish(self):
        self.run.finish()
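

if __name__ == "__main__":
    # Smoke test (a minimal sketch, assuming only loguru is installed; the
    # WandbLogger path is skipped because it needs wandb credentials).
    setup_logger(".", distributed_rank=0, filename="logger_demo.txt", mode="o")
    logger.info("logger configured; messages go to stderr and logger_demo.txt")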