|
import shutil |
|
import os |
|
|
|
import argparse |
|
import yaml |
|
import torch |
|
import sys |
|
|
|
|
|
|
|
|
|
|
|
from qa_mdt.audioldm_train.utilities.data.dataset_original_mos5 import AudioDataset as AudioDataset |
|
from qa_mdt.audioldm_train.utilities.tools import build_dataset_json_from_list |
|
from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config |
|
from torch.utils.data import DataLoader |
|
from pytorch_lightning import seed_everything |
|
from qa_mdt.audioldm_train.utilities.tools import get_restore_step |
|
|
|
|
|
def infer(dataset_key, configs, config_yaml_path, exp_group_name, exp_name):
    """Restore a checkpoint and run ``generate_sample`` over the test split.

    Checkpoint resolution order:
      1. If ``<log_directory>/<exp_group_name>/<exp_name>/checkpoints`` is
         non-empty, the file chosen by ``get_restore_step`` is used.
      2. Otherwise ``configs["reload_from_ckpt"]`` is used, if present.
      3. Otherwise a ``FileNotFoundError`` is raised.

    Args:
        dataset_key: Forwarded to ``AudioDataset`` as ``dataset_json``.
        configs: Parsed configuration mapping. Keys read here:
            ``"precision"`` (optional), ``"log_directory"``, ``"data"``
            (with optional ``"dataloader_add_ons"``), ``"reload_from_ckpt"``
            (optional) and ``"model"`` (including
            ``["params"]["evaluation_params"]``).
        config_yaml_path: Path of the config file; a copy is placed in the
            experiment directory.
        exp_group_name: Directory component under ``log_directory``.
        exp_name: Directory component under ``exp_group_name``.

    Raises:
        FileNotFoundError: If no checkpoint can be resolved.
    """
    seed_everything(0)

    if "precision" in configs:
        torch.set_float32_matmul_precision(configs["precision"])

    log_path = configs["log_directory"]
    dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])

    val_dataset = AudioDataset(
        configs, split="test", add_ons=dataloader_add_ons, dataset_json=dataset_key
    )
    val_loader = DataLoader(val_dataset, batch_size=1)

    # A missing key simply means "no checkpoint given in the config".
    config_reload_from_ckpt = configs.get("reload_from_ckpt")

    experiment_dir = os.path.join(log_path, exp_group_name, exp_name)
    checkpoint_path = os.path.join(experiment_dir, "checkpoints")

    os.makedirs(checkpoint_path, exist_ok=True)
    # Keep a copy of the config next to the outputs of this run.
    shutil.copy(config_yaml_path, experiment_dir)

    if len(os.listdir(checkpoint_path)) > 0:
        # Checkpoints already in the experiment directory take priority over
        # the one named in the config.
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    else:
        # Inference needs weights; fail here with a clear message instead of
        # letting torch.load(None) raise an opaque error after model setup.
        raise FileNotFoundError(
            "No checkpoint found in %s and no 'reload_from_ckpt' was provided"
            % checkpoint_path
        )

    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)

    evaluation_params = configs["model"]["params"]["evaluation_params"]
    guidance_scale = evaluation_params["unconditional_guidance_scale"]
    ddim_sampling_steps = evaluation_params["ddim_sampling_steps"]
    n_candidates_per_samples = evaluation_params["n_candidates_per_samples"]

    # Load onto CPU first so a checkpoint saved from a GPU index that does not
    # exist on this machine can still be read; the model moves to CUDA below.
    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
    # strict=False: keys that do not match between checkpoint and model are
    # tolerated rather than raising.
    latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False)

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()

    latent_diffusion.generate_sample(
        val_loader,
        unconditional_guidance_scale=guidance_scale,
        ddim_steps=ddim_sampling_steps,
        n_gen=n_candidates_per_samples,
    )
|
|
|
|
|
def experiment_names_from_config(config_yaml):
    """Return ``(exp_group_name, exp_name)`` derived from a config file path.

    ``exp_name`` is the file name up to its first dot and ``exp_group_name``
    is the name of the directory that contains the file.

    The file name is isolated *before* splitting on ".", so dots elsewhere in
    the path ("./cfg/exp.yaml", "../cfg/exp.yaml", "v1.2/exp.yaml") do not
    truncate the experiment name.
    """
    exp_name = os.path.basename(config_yaml).split(".")[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml))
    return exp_group_name, exp_name


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        # The path is dereferenced unconditionally below, so let argparse
        # report a missing value instead of failing later on None.
        required=True,
        help="path to config .yaml file",
    )

    parser.add_argument(
        "-l",
        "--list_inference",
        type=str,
        required=False,
        help="The filelist that contain captions (and optionally filenames)",
    )

    parser.add_argument(
        "-reload_from_ckpt",
        "--reload_from_ckpt",
        type=str,
        required=False,
        default=None,
        help="the checkpoint path for the model",
    )

    args = parser.parse_args()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    config_yaml_path = args.config_yaml
    dataset_key = build_dataset_json_from_list(args.list_inference)
    exp_group_name, exp_name = experiment_names_from_config(config_yaml_path)

    # NOTE(review): FullLoader can construct more than plain data types; only
    # point this script at trusted config files.
    with open(config_yaml_path, "r") as config_file:
        config_yaml = yaml.load(config_file, Loader=yaml.FullLoader)

    # A checkpoint given on the command line overrides the one in the config.
    if args.reload_from_ckpt is not None:
        config_yaml["reload_from_ckpt"] = args.reload_from_ckpt

    infer(dataset_key, config_yaml, config_yaml_path, exp_group_name, exp_name)
|
|