import os
from dataclasses import dataclass, field

import pytorch_lightning as pl
import torch.nn.functional as F

import craftsman
from craftsman.utils.base import (
    Updateable,
    update_end_if_possible,
    update_if_possible,
)
from craftsman.utils.scheduler import parse_optimizer, parse_scheduler
from craftsman.utils.config import parse_structured
from craftsman.utils.misc import C, cleanup, get_device, load_module_weights
from craftsman.utils.saving import SaverMixin
from craftsman.utils.typing import *


class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
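    """Base class for CraftsMan training/inference systems.

    Parses the structured ``Config``, optionally loads pretrained weights,
    tracks resume status, builds the optimizer (and scheduler) from the
    config, and forwards epoch/step updates to datasets and Updateable
    sub-modules via the batch-start/batch-end hooks below.
    """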

    @dataclass
    class Config:
        loggers: dict = field(default_factory=dict)
        loss: dict = field(default_factory=dict)
        optimizer: dict = field(default_factory=dict)
        scheduler: Optional[dict] = None
        weights: Optional[str] = None
        weights_ignore_modules: Optional[List[str]] = None
        cleanup_after_validation_step: bool = False
        cleanup_after_test_step: bool = False

        pretrained_model_path: Optional[str] = None
        strict_load: bool = True

    cfg: Config

    def __init__(self, cfg, resumed=False) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self._save_dir: Optional[str] = None
        self._resumed: bool = resumed
        self._resumed_eval: bool = False
        self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}
        if "loggers" in cfg:
            self.create_loggers(cfg.loggers)

        self.configure()
        if self.cfg.weights is not None:
            self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules)
        self.post_configure()

    def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None):
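        """Load a checkpoint from ``weights``, skipping ``ignore_modules``,
        then replay the update step so step-dependent state matches the
        loaded epoch/global step."""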
        state_dict, epoch, global_step = load_module_weights(
            weights, ignore_modules=ignore_modules, map_location="cpu"
        )
        self.load_state_dict(state_dict, strict=False)

        self.do_update_step(epoch, global_step, on_load_weights=True)

    def set_resume_status(self, current_epoch: int, global_step: int):
        # restore the correct epoch and global step when evaluating a resumed run
        self._resumed_eval = True
        self._resumed_eval_status["current_epoch"] = current_epoch
        self._resumed_eval_status["global_step"] = global_step

    @property
    def resumed(self):
        # whether this run was resumed from a checkpoint
        return self._resumed

    @property
    def true_global_step(self):
        if self._resumed_eval:
            return self._resumed_eval_status["global_step"]
        else:
            return self.global_step

    @property
    def true_current_epoch(self):
        if self._resumed_eval:
            return self._resumed_eval_status["current_epoch"]
        else:
            return self.current_epoch

    def configure(self) -> None:
        pass

    def post_configure(self) -> None:
        """
        executed after weights are loaded
        """
        pass

    def C(self, value: Any) -> float:
        # resolve a (possibly epoch/step-scheduled) config value to a float
        return C(value, self.true_current_epoch, self.true_global_step)

    def configure_optimizers(self):
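        """Build the optimizer from ``cfg.optimizer`` and, if configured,
        attach an LR scheduler parsed from ``cfg.scheduler``."""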
        optim = parse_optimizer(self.cfg.optimizer, self)
        ret = {
            "optimizer": optim,
        }
        if self.cfg.scheduler is not None:
            ret.update(
                {
                    "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
                }
            )
        return ret

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.train_dataloader.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.val_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_validation_step:
            # free cached memory between validation steps
            cleanup()

    def on_validation_epoch_end(self):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_test_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.test_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_test_step:
            # free cached memory between test steps
            cleanup()

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        # note: prediction reuses the test-step cleanup flag
        if self.cfg.cleanup_after_test_step:
            cleanup()

    def on_predict_epoch_end(self):
        pass

    def preprocess_data(self, batch, stage):
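        """Hook for stage-specific batch preprocessing; called from the
        *_batch_start hooks below with stage set to "train", "validation",
        "test" or "predict"."""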
        pass

    """
    Implementing on_after_batch_transfer in the DataModule would achieve the same,
    but on_after_batch_transfer does not support DP.
    """

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.preprocess_data(batch, "train")
        self.dataset = self.trainer.train_dataloader.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "validation")
        self.dataset = self.trainer.val_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "test")
        self.dataset = self.trainer.test_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "predict")
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        pass

    def on_before_optimizer_step(self, optimizer):
        """
        # some gradient-related debugging goes here, example:
        from lightning.pytorch.utilities import grad_norm
        norms = grad_norm(self.geometry, norm_type=2)
        print(norms)
        """
        pass
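

# Illustrative usage sketch: a concrete system is expected to subclass
# BaseSystem, extend Config, build its modules in `configure`, and implement
# the step methods. The names `MySystem` and `shape_model` are hypothetical.
#
#     class MySystem(BaseSystem):
#         @dataclass
#         class Config(BaseSystem.Config):
#             shape_model: dict = field(default_factory=dict)
#
#         cfg: Config
#
#         def configure(self):
#             ...  # build sub-modules from self.cfg
#
#         def training_step(self, batch, batch_idx):
#             loss = ...  # compute the training loss from the batch
#             self.log("train/loss", loss)
#             return {"loss": loss}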