from .. import WarpCore
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
from abc import abstractmethod
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import DataLoader
from gdf import GDF
import numpy as np
from tqdm import tqdm
import wandb
import webdataset as wds
from webdataset.handlers import warn_and_continue
from torch.distributed import barrier
from enum import Enum


class TargetReparametrization(Enum):
    EPSILON = 'epsilon'
    X0 = 'x0'


class DiffusionCore(WarpCore):
    @dataclass(frozen=True)
    class Config(WarpCore.Config):
        # TRAINING PARAMS
        lr: float = EXPECTED_TRAIN
        grad_accum_steps: int = EXPECTED_TRAIN
        batch_size: int = EXPECTED_TRAIN
        updates: int = EXPECTED_TRAIN
        warmup_updates: int = EXPECTED_TRAIN
        save_every: int = 500
        backup_every: int = 20000
        use_fsdp: bool = True

        # EMA UPDATE
        ema_start_iters: int = None
        ema_iters: int = None
        ema_beta: float = None

        # GDF setting
        gdf_target_reparametrization: TargetReparametrization = None  # epsilon or x0

    @dataclass()  # not frozen, meaning fields are mutable. Doesn't support EXPECTED
    class Info(WarpCore.Info):
        ema_loss: float = None

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        generator: nn.Module = EXPECTED
        generator_ema: nn.Module = None  # optional

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        generator: any = EXPECTED

    @dataclass(frozen=True)
    class Schedulers(WarpCore.Schedulers):
        generator: any = None

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        gdf: GDF = EXPECTED
        sampling_configs: dict = EXPECTED
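
    # A typical `extras.gdf` setup, for orientation (illustrative only: the
    # exact schedule/weighting classes are assumptions about the companion
    # `gdf` package, not something this base class mandates):
    #
    #   gdf = GDF(
    #       schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
    #       input_scaler=VPScaler(), target=EpsilonTarget(),
    #       noise_cond=CosineTNoiseCond(), loss_weight=P2LossWeight(),
    #   )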

    # --------------------------------------------
    info: Info
    config: Config

    @abstractmethod
    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_path(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_filters(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_preprocessors(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    # -------------

    def setup_data(self, extras: Extras) -> WarpCore.Data:
        # SETUP DATASET
        dataset_path = self.webdataset_path(extras)
        preprocessors = self.webdataset_preprocessors(extras)
        filters = self.webdataset_filters(extras)
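
        # Each preprocessor is a (source_key, map_fn, batch_key) triple, which
        # is what the p[0]/p[1]/p[2] indexing below relies on, e.g. (purely
        # illustrative keys):
        #
        #   [('jpg', to_tensor_fn, 'images'), ('txt', lambda t: t, 'captions')]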
        handler = warn_and_continue  # set to None to raise on malformed samples instead
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(filters).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})

        # SETUP DATALOADER
        real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
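        # e.g. with (illustrative numbers) batch_size=256, world_size=8 and
        # grad_accum_steps=2, each rank loads 256 // 16 = 16 samples per batch.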
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
        )
        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))

    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
        batch = next(data.iterator)

        with torch.no_grad():
            conditions = self.get_conditions(batch, models, extras)
            latents = self.encode_latents(batch, models, extras)
            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

        # FORWARD PASS
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            pred = models.generator(noised, noise_cond, **conditions)
            if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[1]  # transform whatever the model predicts to epsilon for the loss
                target = noise
            elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0]  # transform whatever the model predicts to x0 for the loss
                target = latents
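            # (`gdf.undiffuse` returns an `(x0, epsilon)` pair, which is what
            # the [0]/[1] indexing above relies on.)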
            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

        return loss, loss_adjusted

    def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        start_iter = self.info.iter + 1
        max_iters = self.config.updates * self.config.grad_accum_steps
        if self.is_main_node:
            print(f"STARTING AT STEP: {start_iter}/{max_iters}")

        pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter, max_iters + 1)  # <--- DDP
        models.generator.train()
        grad_norm = torch.tensor(0.0)  # defined up front so logging can't hit an unbound name when resuming mid-accumulation
        for i in pbar:
            # FORWARD PASS
            loss, loss_adjusted = self.forward_pass(data, extras, models)

            # BACKWARD PASS
            if i % self.config.grad_accum_steps == 0 or i == max_iters:
                loss_adjusted.backward()
                grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
                optimizers_dict = optimizers.to_dict()
                for k in optimizers_dict:
                    optimizers_dict[k].step()
                schedulers_dict = schedulers.to_dict()
                for k in schedulers_dict:
                    schedulers_dict[k].step()
                models.generator.zero_grad(set_to_none=True)
                self.info.total_steps += 1
            else:
                with models.generator.no_sync():
                    loss_adjusted.backward()
            self.info.iter = i

            # UPDATE EMA
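            # (update_weights_ema is assumed to blend as ema = beta*ema + (1-beta)*online,
            # so beta=0 before ema_start_iters simply copies the online weights across.)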
            if models.generator_ema is not None and i % self.config.ema_iters == 0:
                update_weights_ema(
                    models.generator_ema, models.generator,
                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
                )

            # UPDATE LOSS METRICS
            self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01

            # parenthesized so the NaN alert only fires on the main node with wandb configured
            if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
                wandb.alert(
                    title=f"NaN value encountered in training run {self.info.wandb_run_id}",
                    text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
                    wait_duration=60 * 30
                )

            if self.is_main_node:
                logs = {
                    'loss': self.info.ema_loss,
                    'raw_loss': loss.mean().item(),
                    'grad_norm': grad_norm.item(),
                    'lr': optimizers.generator.param_groups[0]['lr'],
                    'total_steps': self.info.total_steps,
                }
                pbar.set_postfix(logs)
                if self.config.wandb_project is not None:
                    wandb.log(logs)

            if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters:
                # SAVE AND CHECKPOINT STUFF
                if np.isnan(loss.mean().item()):
                    if self.is_main_node and self.config.wandb_project is not None:
                        tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
                        wandb.alert(
                            title=f"Skipping sampling & checkpoint for training run {self.config.run_id}",
                            text=f"Skipping sampling & checkpoint at {self.info.total_steps} steps for training run {self.info.wandb_run_id} because the loss is NaN"
                        )
                else:
                    self.save_checkpoints(models, optimizers)
                    if self.is_main_node:
                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                    self.sample(models, data, extras)

    def models_to_save(self):
        return ['generator', 'generator_ema']

    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
        barrier()
        suffix = '' if suffix is None else suffix
        self.save_info(self.info, suffix=suffix)
        models_dict = models.to_dict()
        optimizers_dict = optimizers.to_dict()
        for key in self.models_to_save():
            model = models_dict[key]
            if model is not None:
                self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
        for key in optimizers_dict:
            optimizer = optimizers_dict[key]
            if optimizer is not None:
                self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
        if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
        torch.cuda.empty_cache()
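

# ------------------------------------------------------------------
# Minimal subclass sketch (illustrative only). Everything below -- the
# pixel-space "latents", the shard path, the 'jpg'/'txt' keys and the
# unconditional generator -- is a hypothetical example of filling in the
# abstract interface above, not one of the trainers shipped with this repo.
class PixelDiffusionCore(DiffusionCore):
    def encode_latents(self, batch: dict, models: DiffusionCore.Models, extras: DiffusionCore.Extras) -> torch.Tensor:
        return batch['images']  # train directly in pixel space: images are the "latents"

    def decode_latents(self, latents: torch.Tensor, batch: dict, models: DiffusionCore.Models, extras: DiffusionCore.Extras) -> torch.Tensor:
        return latents

    def get_conditions(self, batch: dict, models: DiffusionCore.Models, extras: DiffusionCore.Extras, is_eval=False, is_unconditional=False):
        return {}  # unconditional model: no extra kwargs for the generator

    def webdataset_path(self, extras: DiffusionCore.Extras):
        return '/data/shards-{00000..00099}.tar'  # assumed shard location

    def webdataset_filters(self, extras: DiffusionCore.Extras):
        return lambda sample: True  # keep every sample

    def webdataset_preprocessors(self, extras: DiffusionCore.Extras):
        import torchvision.transforms as T  # local import: only this sketch needs it
        transform = T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()])
        return [
            ('jpg', transform, 'images'),     # (source key, map fn, batch key)
            ('txt', lambda t: t, 'captions'),
        ]

    def sample(self, models: DiffusionCore.Models, data: WarpCore.Data, extras: DiffusionCore.Extras):
        pass  # periodic sampling/visualisation omitted in this sketch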