from typing import Callable, Iterable, Optional, Union, Sequence
import copy
import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from lightning.pytorch.callbacks import Callback
from src.models.vae import BaseVAE, fp2uint8
from src.models.conditioner import BaseConditioner
from src.utils.model_loader import ModelLoader
from src.callbacks.simple_ema import SimpleEMA
from src.diffusion.base.sampling import BaseSampler
from src.diffusion.base.training import BaseTrainer
from src.utils.no_grad import no_grad, filter_nograd_tensors
from src.utils.copy import copy_params

# factory aliases: callables that finish constructing the EMA tracker,
# optimizer, and LR scheduler once the target modules/parameters exist
EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]
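# Illustrative only: these factories are typically supplied partially applied
# (e.g. from a config system or functools.partial), binding hyperparameters up
# front and leaving the final positional argument for this module to fill in.
# The `decay` keyword for SimpleEMA below is an assumption about its signature,
# not something this file defines.
#
#   import functools
#   optimizer = functools.partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.0)
#   lr_scheduler = functools.partial(
#       torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100_000)
#   ema_tracker = functools.partial(SimpleEMA, decay=0.9999)  # hypothetical kwargs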
class LightningModel(pl.LightningModule):
    def __init__(self,
                 vae: BaseVAE,
                 conditioner: BaseConditioner,
                 denoiser: nn.Module,
                 diffusion_trainer: BaseTrainer,
                 diffusion_sampler: BaseSampler,
                 ema_tracker: Optional[EMACallable] = None,
                 optimizer: Optional[OptimizerCallable] = None,
                 lr_scheduler: Optional[LRSchedulerCallable] = None,
                 ):
        super().__init__()
        self.vae = vae
        self.conditioner = conditioner
        self.denoiser = denoiser
        self.ema_denoiser = copy.deepcopy(self.denoiser)
        self.diffusion_sampler = diffusion_sampler
        self.diffusion_trainer = diffusion_trainer
        self.ema_tracker = ema_tracker
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        # self.model_loader = ModelLoader()
        # checkpoints are partial (see state_dict below), so load non-strictly
        self._strict_loading = False
    def configure_model(self) -> None:
        self.trainer.strategy.barrier()
        # self.denoiser = self.model_loader.load(self.denoiser)
        copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)
        # self.denoiser = torch.compile(self.denoiser)
        # freeze everything that is not trained directly: the conditioner,
        # the VAE, the sampler, and the EMA copy of the denoiser
        no_grad(self.conditioner)
        no_grad(self.vae)
        no_grad(self.diffusion_sampler)
        no_grad(self.ema_denoiser)
    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
        # ema_tracker is optional; skip EMA tracking when no factory is given
        if self.ema_tracker is None:
            return []
        ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
        return [ema_tracker]
    def configure_optimizers(self) -> OptimizerLRScheduler:
        # optimize the denoiser together with any trainable parameters owned
        # by the diffusion trainer; parameters without grad are filtered out
        params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
        params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
        optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
        if self.lr_scheduler is None:
            return dict(optimizer=optimizer)
        lr_scheduler = self.lr_scheduler(optimizer)
        return dict(optimizer=optimizer, lr_scheduler=lr_scheduler)
    def training_step(self, batch, batch_idx):
        raw_images, x, y = batch
        with torch.no_grad():
            # encode images to latents and embed the conditioning labels
            x = self.vae.encode(x)
            condition, uncondition = self.conditioner(y)
        loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
        self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
        return loss["loss"]
    def predict_step(self, batch, batch_idx):
        xT, y, metadata = batch
        with torch.no_grad():
            condition, uncondition = self.conditioner(y)
            # sample latents starting from noise xT, then decode to images
            samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
            samples = self.vae.decode(samples)
        # fp32 in [-1, 1] -> uint8 in [0, 255]
        samples = fp2uint8(samples)
        return samples
    def validation_step(self, batch, batch_idx):
        # validation reuses the sampling path from predict_step
        samples = self.predict_step(batch, batch_idx)
        return samples
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # checkpoint only this module's own tensors plus the denoiser, its EMA
        # copy, and the diffusion trainer; the frozen VAE and conditioner are
        # deliberately omitted (hence self._strict_loading = False)
        if destination is None:
            destination = {}
        self._save_to_state_dict(destination, prefix, keep_vars)
        self.denoiser.state_dict(
            destination=destination,
            prefix=prefix + "denoiser.",
            keep_vars=keep_vars)
        self.ema_denoiser.state_dict(
            destination=destination,
            prefix=prefix + "ema_denoiser.",
            keep_vars=keep_vars)
        self.diffusion_trainer.state_dict(
            destination=destination,
            prefix=prefix + "diffusion_trainer.",
            keep_vars=keep_vars)
        return destination
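
# Minimal usage sketch (illustrative only): the concrete vae/conditioner/
# denoiser/trainer/sampler objects and the datamodule are hypothetical
# stand-ins for subclasses of the imported base classes, and the factories
# are the ones sketched near the type aliases above.
#
#   model = LightningModel(
#       vae=my_vae,                      # a BaseVAE subclass
#       conditioner=my_conditioner,      # a BaseConditioner subclass
#       denoiser=my_denoiser,            # an nn.Module predicting the diffusion target
#       diffusion_trainer=my_trainer,    # a BaseTrainer subclass
#       diffusion_sampler=my_sampler,    # a BaseSampler subclass
#       ema_tracker=ema_tracker,
#       optimizer=optimizer,
#       lr_scheduler=lr_scheduler,
#   )
#   pl.Trainer(max_steps=100_000, precision="bf16-mixed").fit(model, my_datamodule)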