Spaces:
Runtime error
Runtime error
| import argparse | |
| from datetime import datetime | |
| import random | |
| import os | |
| import time | |
| import multiprocessing | |
| # Set multiprocessing start method to 'spawn' to avoid CUDA initialization issues in forked processes | |
| multiprocessing.set_start_method('spawn', force=True) | |
| from tqdm.auto import tqdm # Progress bar | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR, ExponentialLR # Importing CosineAnnealingLR scheduler | |
| import torch.nn.functional as F | |
| from accelerate import Accelerator, DistributedDataParallelKwargs | |
| from accelerate.utils import set_seed # Removed get_scheduler import | |
| from peft import get_peft_model, LoraConfig | |
| from modeling import VMemModel | |
| from modeling.modules.autoencoder import AutoEncoder | |
| from modeling.sampling import DDPMDiscretization, DiscreteDenoiser, create_samplers | |
| from modeling.modules.conditioner import CLIPConditioner | |
| from utils.training_utils import DiffusionTrainer, load_pretrained_model | |
| from data.dataset import RealEstatePoseImageSevaDataset | |
# Seed PyTorch, Python's `random` and NumPy at import time so runs are
# reproducible; main() may re-seed afterwards via accelerate's set_seed when
# `training.random_seed` is configured.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(42)
del _seed_fn
def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with a single ``config`` attribute: the path of the
        OmegaConf config file that main() loads. ``--config`` is mandatory.
    """
    cli = argparse.ArgumentParser(description='Train a model')
    cli.add_argument(
        '--config',
        type=str,
        default="",
        required=True,
        help='Path to the config file',
    )
    return cli.parse_args()
def generate_current_datetime():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``.

    The format contains no spaces or colons, so the result can be used
    directly as a directory name (main() uses it as the run id).
    """
    return f"{datetime.now():%Y-%m-%d_%H-%M-%S}"
def _lora_target_modules(named_parameters):
    """Return the de-duplicated module names that should receive LoRA adapters.

    A parameter is selected when its name contains "transformer", "emb" or
    "layers", and does not contain "norm", "in_layers.0" or "out_layers.0"
    (the latter two are presumably the normalisation layers of ResBlocks --
    confirm). The owning module name is obtained by dropping a trailing
    ``.weight`` / ``.bias`` component only; names are returned in order of
    first appearance.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, e.g.
            ``module.named_parameters()``.

    Returns:
        list[str] of module names suitable for ``LoraConfig.target_modules``.
    """
    selected = {}  # dict used as an insertion-ordered set
    for param_name, _ in named_parameters:
        if not ("transformer" in param_name or "emb" in param_name or "layers" in param_name):
            continue
        if "norm" in param_name or "in_layers.0" in param_name or "out_layers.0" in param_name:
            continue
        # Strip only the final ".weight"/".bias" component. A plain
        # str.replace would also mangle names such as "x.weights" -> "xs".
        module_name, _, leaf = param_name.rpartition(".")
        if leaf not in ("weight", "bias") or not module_name:
            module_name = param_name
        selected[module_name] = None
    return list(selected)


def prepare_model(unet, config):
    """Configure which parameters of ``unet`` are trained.

    If ``config.training.lora_flag`` is truthy, wrap ``unet`` with PEFT LoRA
    adapters (``lora_r`` / ``lora_alpha`` / ``lora_dropout`` from
    ``config.training``) on the modules chosen by ``_lora_target_modules``.
    Otherwise, mark every parameter as trainable for full fine-tuning. In
    both cases, print the fraction of parameters that require gradients.

    Args:
        unet: the VMemModel to adapt.
        config: OmegaConf config with a ``training`` section.

    Returns:
        The model to train: the PEFT-wrapped model in LoRA mode, otherwise
        ``unet`` itself.
    """
    assert isinstance(unet, VMemModel), "unet should be an instance of VMemModel"
    if config.training.lora_flag:
        target_modules = _lora_target_modules(unet.named_parameters())
        lora_config = LoraConfig(
            r=config.training.lora_r,
            lora_alpha=config.training.lora_alpha,
            target_modules=target_modules,
            lora_dropout=config.training.lora_dropout,
        )
        unet = get_peft_model(unet, lora_config)
        unet.print_trainable_parameters()
    else:
        # Full fine-tuning: every parameter receives gradients.
        unet.requires_grad_(True)
    trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    total = sum(p.numel() for p in unet.parameters())
    print("trainable parameters percentage: ", trainable / total if total else 0.0)
    return unet
def _shared_experiment_id():
    """Return a timestamp run id that is identical on every process.

    Each rank would otherwise call generate_current_datetime() at a slightly
    different moment and could land on a different second, giving every rank
    its own output directories. Rank 0's value is broadcast to the others.
    With a single process, this is just the local timestamp. Call it only
    after the Accelerator has been created, so the distributed state exists.
    """
    from accelerate.utils import broadcast_object_list

    holder = [generate_current_datetime()]
    broadcast_object_list(holder, from_process=0)
    return holder[0]


def _make_run_dir(root, enabled, exp_id):
    """Create and return ``f"{root}/{exp_id}"`` when ``enabled``; else return None."""
    if not enabled:
        return None
    run_dir = f"{root}/{exp_id}"
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def _build_lr_scheduler(optimizer, total_steps, warmup_steps):
    """Build a cosine-annealing LR schedule, optionally preceded by linear warmup.

    The cosine phase spans ``total_steps - warmup_steps`` scheduler steps,
    clamped to at least 1 so a warmup covering the whole run does not give
    CosineAnnealingLR a zero period. When ``warmup_steps > 0``, a LambdaLR
    ramps the LR linearly from 0, and SequentialLR hands over to the cosine
    schedule at step ``warmup_steps``.
    """
    cosine = CosineAnnealingLR(
        optimizer, T_max=max(1, total_steps - warmup_steps), eta_min=0
    )
    if warmup_steps <= 0:
        return cosine

    def warmup_lambda(current_step):
        # Linear ramp from 0 to 1 over the warmup steps.
        return float(current_step) / float(warmup_steps)

    warmup = LambdaLR(optimizer, lr_lambda=warmup_lambda)
    return SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps]
    )


def main():
    """Fine-tune VMemModel on RealEstate10K with Accelerate.

    Reads the OmegaConf file passed as ``--config``, then:
      1. creates the Accelerator (fp16, gradient accumulation) and, when
         configured, re-seeds RNGs;
      2. picks one run id shared by all processes and creates the
         visualization / checkpoint directories for it when enabled;
      3. loads the pretrained model, applies the LoRA or full fine-tuning
         setup, and optionally loads weights from ``training.resume``;
      4. builds the train/test datasets, AdamW optimizer and LR schedule;
      5. builds the VAE, CLIP conditioner, denoiser and sampler, prepares
         everything with Accelerate, and runs DiffusionTrainer.train.
    """
    args = parse_args()
    config_path = args.config
    config = OmegaConf.load(config_path)
    # Training hyper-parameters from the config.
    num_epochs = config.training.num_epochs
    batch_size = config.training.batch_size
    learning_rate = config.training.learning_rate
    gradient_accumulation_steps = config.training.gradient_accumulation_steps
    num_workers = config.training.num_workers
    warmup_epochs = config.training.warmup_epochs
    max_grad_norm = config.training.max_grad_norm
    validation_interval = config.training.validation_interval
    visualization_flag = config.training.visualization_flag
    visualize_every = config.training.visualize_every
    random_seed = config.training.random_seed
    save_flag = config.training.save_flag
    use_wandb = config.training.use_wandb
    samples_dir = config.training.samples_dir
    weights_save_dir = config.training.weights_save_dir
    resume = config.training.resume

    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=gradient_accumulation_steps,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=False)],
    )

    # The run id is created after the Accelerator so rank 0's timestamp can be
    # broadcast; every rank then writes to the same directories.
    exp_id = _shared_experiment_id()
    run_visualization_dir = _make_run_dir(samples_dir, visualization_flag, exp_id)
    run_weights_save_dir = _make_run_dir(weights_save_dir, save_flag, exp_id)

    if random_seed is not None:
        set_seed(random_seed, device_specific=True)
    device = accelerator.device

    model = load_pretrained_model(cache_dir=config.model.cache_dir, device=device)
    model = prepare_model(model, config)
    if resume:
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
        load_result = model.load_state_dict(torch.load(resume, map_location='cpu'), strict=False)
        # strict=False ignores mismatched keys. Report them so that a checkpoint
        # whose key names do not match the (possibly LoRA-wrapped) model does
        # not silently load nothing.
        if load_result is not None:
            accelerator.print(
                f"Loaded {resume}: {len(load_result.missing_keys)} missing keys, "
                f"{len(load_result.unexpected_keys)} unexpected keys"
            )
    torch.cuda.empty_cache()

    realestate_cfg = config.dataset.realestate10k
    train_dataset = RealEstatePoseImageSevaDataset(
        rgb_data_dir=realestate_cfg.rgb_data_dir,
        meta_info_dir=realestate_cfg.meta_info_dir,
        num_sample_per_episode=realestate_cfg.num_sample_per_episode,
        mode='train',
    )
    val_dataset = RealEstatePoseImageSevaDataset(
        rgb_data_dir=realestate_cfg.rgb_data_dir,
        meta_info_dir=realestate_cfg.meta_info_dir,
        num_sample_per_episode=realestate_cfg.val_num_sample_per_episode,
        mode='test',
    )
    # DataLoader rejects a multiprocessing_context when num_workers == 0, so
    # only request 'spawn' workers when worker processes are actually used.
    mp_context = 'spawn' if num_workers > 0 else None
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers, multiprocessing_context=mp_context)
    # NOTE(review): validation is shuffled too -- presumably so each validation /
    # visualization pass sees different scenes; confirm this is intended.
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers, multiprocessing_context=mp_context)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                  weight_decay=config.training.weight_decay)

    # NOTE(review): step counts use len() of the un-sharded dataloader (taken
    # before accelerator.prepare). Whether this matches the number of
    # scheduler.step() calls depends on how DiffusionTrainer steps the prepared
    # scheduler and on gradient_accumulation_steps -- confirm.
    train_steps_per_epoch = len(train_dataloader)
    total_train_steps = num_epochs * train_steps_per_epoch
    # int() so a fractional warmup_epochs still yields an integer milestone.
    warmup_steps = int(warmup_epochs * train_steps_per_epoch)
    lr_scheduler = _build_lr_scheduler(optimizer, total_train_steps, warmup_steps)

    vae = AutoEncoder(chunk_size=1).to(device)
    vae.eval()
    conditioner = CLIPConditioner().to(device)
    discretization = DDPMDiscretization()
    denoiser = DiscreteDenoiser(discretization=discretization, num_idx=1000, device=device)
    sampler = create_samplers(guider_types=config.training.guider_types,
                              discretization=discretization,
                              num_frames=config.model.num_frames,
                              num_steps=config.training.inference_num_steps,
                              cfg_min=config.training.cfg_min,
                              device=device)

    (model,
     vae,
     train_dataloader,
     val_dataloader,
     optimizer,
     lr_scheduler) = accelerator.prepare(
        model,
        vae,
        train_dataloader,
        val_dataloader,
        optimizer,
        lr_scheduler,
    )

    trainer = DiffusionTrainer(network=model,
                               ae=vae,
                               conditioner=conditioner,
                               denoiser=denoiser,
                               sampler=sampler,
                               discretization=discretization,
                               cfg=config.training.cfg,
                               optimizer=optimizer,
                               lr_scheduler=lr_scheduler,
                               ema_decay=config.training.ema_decay,
                               device=device,
                               accelerator=accelerator,
                               max_grad_norm=max_grad_norm,
                               save_flag=save_flag,
                               visualize_flag=visualization_flag)
    trainer.train(train_dataloader,
                  num_epochs,
                  unconditional_prob=config.training.uncond_prob,
                  log_every=10,
                  validation_dataloader=val_dataloader,
                  validation_interval=validation_interval,
                  save_dir=run_weights_save_dir,
                  save_interval=config.training.save_every,
                  visualize_every=visualize_every,
                  visualize_dir=run_visualization_dir,
                  use_wandb=use_wandb)
if __name__ == "__main__":
    # Run training only when executed as a script. The 'spawn' start method is
    # set at import time, and spawned worker processes re-import this module;
    # this guard is what stops them from re-running main().
    main()