import sys

# sys.path.append("src")
import shutil
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

import argparse
import yaml
import torch
from tqdm import tqdm
from pytorch_lightning.strategies.ddp import DDPStrategy
from qa_mdt.audioldm_train.modules.latent_diffusion.ddpm import LatentDiffusion
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from qa_mdt.audioldm_train.utilities.tools import (
    listdir_nohidden,
    get_restore_step,
    copy_test_subset_data,
)
import wandb
from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config
import logging

logging.basicConfig(level=logging.WARNING)


def convert_path(path):
    # Keep only the last four components of a byte-encoded path (dataset-relative path).
    parts = path.decode().split("/")[-4:]
    return "/".join(parts)


def print_on_rank0(msg):
    # Only the rank-0 process prints, to avoid duplicated logs under DDP.
    if torch.distributed.get_rank() == 0:
        print(msg)


def main(configs, config_yaml_path, exp_group_name, exp_name, perform_validation):
    print("MAIN START")
    # cpth = "/train20/intern/permanent/changli7/dataset_ptm/test_dataset/dataset/audioset/zip_audios/unbalanced_train_segments/unbalanced_train_segments_part9/Y7fmOlUlwoNg.wav"
    # convert_path(cpth)

    if "seed" in configs.keys():
        seed_everything(configs["seed"])
    else:
        print("SEED EVERYTHING TO 1234")
        seed_everything(1234)

    if "precision" in configs.keys():
        torch.set_float32_matmul_precision(
            configs["precision"]
        )  # highest, high, medium

    log_path = configs["log_directory"]
    batch_size = configs["model"]["params"]["batchsize"]

    # Training/validation LMDB shards; each training shard has a data_key.key index next to it.
    train_lmdb_path = configs["train_path"]["train_lmdb_path"]
    train_key_path = [p + "/data_key.key" for p in train_lmdb_path]
    val_lmdb_path = configs["val_path"]["val_lmdb_path"]
    val_key_path = configs["val_path"]["val_key_path"]
    mos_path = configs["mos_path"]

    from qa_mdt.audioldm_train.utilities.data.hhhh import AudioDataset

    dataset = AudioDataset(
        config=configs,
        lmdb_path=train_lmdb_path,
        key_path=train_key_path,
        mos_path=mos_path,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        shuffle=True,
    )

    print(
        "The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s"
        % (len(dataset), len(loader), batch_size)
    )

    try:
        val_dataset = AudioDataset(
            config=configs,
            lmdb_path=val_lmdb_path,
            key_path=val_key_path,
            mos_path=mos_path,
        )
    except Exception:
        # Fall back to a validation set built without MOS labels.
        val_dataset = AudioDataset(
            config=configs,
            lmdb_path=val_lmdb_path,
            key_path=val_key_path,
        )

    val_loader = DataLoader(
        val_dataset,
        batch_size=8,
    )

    # Copy test data
    test_data_subset_folder = os.path.join(
        os.path.dirname(configs["log_directory"]),
        "testset_data",
        "tmp",
    )
    os.makedirs(test_data_subset_folder, exist_ok=True)
    # copy to test:
    # import pdb
    # pdb.set_trace()
    # for i in range(len(val_dataset.keys)):
    #     key_tmp = val_dataset.keys[i].decode()
    #     cmd = "cp {} {}".format(key_tmp, os.path.join(test_data_subset_folder))
    #     os.system(cmd)

    config_reload_from_ckpt = configs.get("reload_from_ckpt", None)
    limit_val_batches = configs["step"].get("limit_val_batches", None)
    validation_every_n_epochs = configs["step"]["validation_every_n_epochs"]
    save_checkpoint_every_n_steps = configs["step"]["save_checkpoint_every_n_steps"]
    max_steps = configs["step"]["max_steps"]
    save_top_k = configs["step"]["save_top_k"]

    checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
    wandb_path = os.path.join(log_path, exp_group_name, exp_name)
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="global_step",
        mode="max",
        filename="checkpoint-fad-{val/frechet_inception_distance:.2f}-global_step={global_step:.0f}",
        every_n_train_steps=save_checkpoint_every_n_steps,
        save_top_k=save_top_k,
        auto_insert_metric_name=False,
        save_last=False,
    )

    os.makedirs(checkpoint_path, exist_ok=True)
    # shutil.copy(config_yaml_path, wandb_path)

    # Resume priority: latest checkpoint in the run directory, then the checkpoint named in the
    # config, otherwise train from scratch.
    if len(os.listdir(checkpoint_path)) > 0:
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, n_step = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    else:
        print("Train from scratch")
        resume_from_checkpoint = None

    devices = torch.cuda.device_count()

    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)

    wandb_logger = WandbLogger(
        save_dir=wandb_path,
        project=configs["project"],
        config=configs,
        name="%s/%s" % (exp_group_name, exp_name),
    )

    latent_diffusion.test_data_subset_path = test_data_subset_folder

    print("==> Save checkpoint every %s steps" % save_checkpoint_every_n_steps)
    print("==> Perform validation every %s epochs" % validation_every_n_epochs)

    trainer = Trainer(
        accelerator="auto",
        devices="auto",
        logger=wandb_logger,
        max_steps=max_steps,
        num_sanity_val_steps=1,
        limit_val_batches=limit_val_batches,
        check_val_every_n_epoch=validation_every_n_epochs,
        strategy=DDPStrategy(find_unused_parameters=True),
        gradient_clip_val=2.0,
        callbacks=[checkpoint_callback],
        num_nodes=1,
    )

    trainer.fit(latent_diffusion, loader, val_loader, ckpt_path=resume_from_checkpoint)

    ################################################################################################################
    # if resume_from_checkpoint is not None:
    #     ckpt = torch.load(resume_from_checkpoint)["state_dict"]
    #     key_not_in_model_state_dict = []
    #     size_mismatch_keys = []
    #     state_dict = latent_diffusion.state_dict()
    #     print("Filtering key for reloading:", resume_from_checkpoint)
    #     print("State dict key size:", len(list(state_dict.keys())), len(list(ckpt.keys())))
    #     for key in tqdm(list(ckpt.keys())):
    #         if key not in state_dict.keys():
    #             key_not_in_model_state_dict.append(key)
    #             del ckpt[key]
    #             continue
    #         if state_dict[key].size() != ckpt[key].size():
    #             del ckpt[key]
    #             size_mismatch_keys.append(key)
    #     if len(key_not_in_model_state_dict) != 0 or len(size_mismatch_keys) != 0:
    #         print("⛳", end=" ")
    #         print("==> Warning: The following key in the checkpoint is not presented in the model:", key_not_in_model_state_dict)
    #         print("==> Warning: These keys have different size between checkpoint and current model: ", size_mismatch_keys)
    #     latent_diffusion.load_state_dict(ckpt, strict=False)
    # if perform_validation:
    #     trainer.validate(latent_diffusion, val_loader)
    # trainer.fit(latent_diffusion, loader, val_loader)
    ################################################################################################################


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        required=False,
        help="path to config .yaml file",
    )
    parser.add_argument("--val", action="store_true")
    args = parser.parse_args()

    perform_validation = args.val

    assert torch.cuda.is_available(), "CUDA is not available"

    config_yaml = args.config_yaml
    exp_name = os.path.basename(config_yaml.split(".")[0])
    exp_group_name = os.path.basename(os.path.dirname(config_yaml))

    config_yaml_path = config_yaml
    with open(config_yaml_path, "r") as f:
        config_yaml = yaml.load(f, Loader=yaml.FullLoader)

    if perform_validation:
        # Validation-only runs should not use ground-truth AudioMAE outputs and should not
        # limit the number of validation batches.
        config_yaml["model"]["params"]["cond_stage_config"][
            "crossattn_audiomae_generated"
        ]["params"]["use_gt_mae_output"] = False
        config_yaml["step"]["limit_val_batches"] = None

    main(config_yaml, config_yaml_path, exp_group_name, exp_name, perform_validation)
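
# ----------------------------------------------------------------------------------------------
# For reference: a minimal sketch of the YAML layout this script reads. Key names are taken from
# the config accesses above; all paths and values below are placeholders, not from the original
# repo. A real config additionally needs the full `model` definition expected by
# instantiate_from_config, including `cond_stage_config.crossattn_audiomae_generated` when
# running with --val.
#
# project: "audioldm_train"              # placeholder W&B project name
# log_directory: "./log/latent_diffusion"
# seed: 1234                             # optional; defaults to 1234
# precision: "high"                      # optional; one of highest / high / medium
# mos_path: "./data/mos"                 # MOS (quality) labels consumed by AudioDataset
# reload_from_ckpt: null                 # optional ckpt, used only if the run has no checkpoint yet
# train_path:
#   train_lmdb_path: ["./data/train_lmdb"]          # each entry must contain <dir>/data_key.key
# val_path:
#   val_lmdb_path: ["./data/val_lmdb"]
#   val_key_path: ["./data/val_lmdb/data_key.key"]  # list assumed, mirroring train_key_path
# step:
#   validation_every_n_epochs: 5
#   save_checkpoint_every_n_steps: 5000
#   max_steps: 800000
#   save_top_k: 1
#   limit_val_batches: null              # optional
# model:
#   params:
#     batchsize: 8
#     ...                                # remaining latent-diffusion parameters
# ----------------------------------------------------------------------------------------------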