#!/usr/bin/env python3

import logging
import os
import sys
import traceback
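
# Pin the numeric backends (OpenMP, OpenBLAS, MKL, Accelerate, numexpr) to a
# single thread each, and do it before those libraries are imported; with DDP
# workers and DataLoader subprocesses, per-library thread pools would
# otherwise oversubscribe the CPU.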
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import hydra
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin

from saicinpainting.training.trainers import make_training_model
from saicinpainting.utils import register_debug_signal_handlers, handle_ddp_subprocess, handle_ddp_parent_process, \
    handle_deterministic_config

LOGGER = logging.getLogger(__name__)


@handle_ddp_subprocess()
# NOTE: the config_path/config_name below are assumptions about the project
# layout; point Hydra at the training configs of your checkout. Without
# @hydra.main, calling main() with no arguments at the bottom would fail.
@hydra.main(config_path='../configs/training', config_name='tiny_test.yaml')
def main(config: OmegaConf):
    try:
        need_set_deterministic = handle_deterministic_config(config)

        register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

        is_in_ddp_subprocess = handle_ddp_parent_process()

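        # Under Hydra's default (pre-1.2) behaviour, main() runs with the cwd
        # switched to the per-run output directory, so the os.getcwd() paths
        # below all land inside that run directory.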
        config.visualizer.outdir = os.path.join(os.getcwd(), config.visualizer.outdir)
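        # Only the DDP parent process (or a single-process run) prints and
        # saves the resolved config; spawned DDP workers skip this to avoid
        # duplicate writes.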
        if not is_in_ddp_subprocess:
            LOGGER.info(OmegaConf.to_yaml(config))
            OmegaConf.save(config, os.path.join(os.getcwd(), 'config.yaml'))

        checkpoints_dir = os.path.join(os.getcwd(), 'models')
        os.makedirs(checkpoints_dir, exist_ok=True)

        # there is no need to suppress this logger in ddp, because it handles rank on its own
        metrics_logger = TensorBoardLogger(config.location.tb_dir, name=os.path.basename(os.getcwd()))
        metrics_logger.log_hyperparams(config)

        training_model = make_training_model(config)

        trainer_kwargs = OmegaConf.to_container(config.trainer.kwargs, resolve=True)
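        # deterministic=True asks Lightning to select deterministic PyTorch
        # algorithms, trading some throughput for reproducible runs.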
        if need_set_deterministic:
            trainer_kwargs['deterministic'] = True

        trainer = Trainer(
            # there is no need to suppress checkpointing in ddp, because it handles rank on its own
            callbacks=ModelCheckpoint(dirpath=checkpoints_dir, **config.trainer.checkpoint_kwargs),
            logger=metrics_logger,
            default_root_dir=os.getcwd(),
            **trainer_kwargs
        )

        trainer.fit(training_model)
    except KeyboardInterrupt:
        LOGGER.warning('Interrupted by user')
    except Exception as ex:
        LOGGER.critical(f'Training failed due to {ex}:\n{traceback.format_exc()}')
        sys.exit(1)


if __name__ == '__main__':
    main()
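
# Example invocation (a sketch assuming the Hydra 1.x CLI; the config name and
# override values are illustrative, not taken from this file):
#   python3 train.py -cn tiny_test location.tb_dir=/tmp/tb trainer.kwargs.gpus=1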