import os
import copy
import json
import time
import torch
import argparse
import soundfile as sf
from tqdm import tqdm
from diffusers import DDPMScheduler
from audioldm_eval import EvaluationHelper
from models import build_pretrained_models, AudioDiffusion
from transformers import AutoProcessor, ClapModel
import torchaudio
from tango import Tango


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
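
# Example (illustrative): attribute access stands in for key lookup, e.g.
#   cfg = dotdict({"num_steps": 200}); cfg.num_steps  -> 200
# Missing keys return None (via dict.get) instead of raising AttributeError.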


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
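
# Example (illustrative): list(chunks([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]];
# the last chunk may be shorter than n. (This helper is defined here, though the
# generation loop below slices batches directly.)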


def parse_args():
    parser = argparse.ArgumentParser(description="Inference for the text-to-audio generation task.")
    parser.add_argument(
        "--checkpoint", type=str, default="declare-lab/tango",
        help="Tango HuggingFace checkpoint."
    )
    parser.add_argument(
        "--test_file", type=str, default="data/test_audiocaps_subset.json",
        help="JSON lines file containing the test prompts for generation."
    )
    parser.add_argument(
        "--text_key", type=str, default="captions",
        help="Key containing the text in each line of the test file."
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0",
        help="Device to use for inference."
    )
    parser.add_argument(
        "--test_references", type=str, default="data/audiocaps_test_references/subset",
        help="Folder containing the test reference wav files."
    )
    parser.add_argument(
        "--num_steps", type=int, default=200,
        help="How many denoising steps for generation."
    )
    parser.add_argument(
        "--guidance", type=float, default=3,
        help="Guidance scale for classifier-free guidance."
    )
    parser.add_argument(
        "--batch_size", type=int, default=8,
        help="Batch size for generation."
    )
    return parser.parse_args()
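
# Illustrative invocation with the repo defaults (assuming this script is saved
# as inference.py; adjust paths to your own data):
#   python inference.py --checkpoint declare-lab/tango --num_steps 200 \
#       --guidance 3 --batch_size 8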


def main():
    args = parse_args()
    num_steps, guidance, batch_size = args.num_steps, args.guidance, args.batch_size
    checkpoint = args.checkpoint

    # Load Models #
    tango = Tango(checkpoint, args.device)
    vae, stft, model = tango.vae, tango.stft, tango.model
    scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
    # Use the configurable device rather than a hardcoded "cuda:0".
    evaluator = EvaluationHelper(16000, args.device)
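
    # Tango bundles the latent diffusion model (model), the mel-spectrogram VAE
    # (vae), and an STFT helper (stft, loaded but apparently unused in this
    # inference path). The DDPM noise schedule is reused from the Stable
    # Diffusion 2.1 config; EvaluationHelper scores outputs against the 16 kHz
    # reference wavs.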

    # Load Data #
    prefix = ""
    with open(args.test_file) as f:
        text_prompts = [prefix + json.loads(line)[args.text_key] for line in f]

    exp_id = str(int(time.time()))
    # makedirs with exist_ok=True also creates the parent "outputs" directory.
    output_dir = "outputs/{}_steps_{}_guidance_{}".format(exp_id, num_steps, guidance)
    os.makedirs(output_dir, exist_ok=True)

    # Generate #
    all_outputs = []
    for k in tqdm(range(0, len(text_prompts), batch_size)):
        text = text_prompts[k: k + batch_size]
        with torch.no_grad():
            # Denoise text-conditioned latents, then decode
            # latents -> mel spectrogram -> waveform.
            latents = model.inference(text, scheduler, num_steps, guidance)
            mel = vae.decode_first_stage(latents)
            wave = vae.decode_to_waveform(mel)
        all_outputs += [item for item in wave]
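
    # Each entry in all_outputs is one generated 16 kHz waveform, so
    # len(wav) / 16000 gives a clip's duration in seconds.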

    # Save #
    for j, wav in enumerate(all_outputs):
        sf.write("{}/output_{}.wav".format(output_dir, j), wav, samplerate=16000)

    result = evaluator.main(output_dir, args.test_references)
    result["Steps"] = num_steps
    result["Guidance Scale"] = guidance
    result["Test Instances"] = len(text_prompts)
    result["scheduler_config"] = dict(scheduler.config)
    result["args"] = dict(vars(args))
    result["output_dir"] = output_dir

    # Append one JSON record per run; a single trailing newline keeps the
    # .jsonl file valid (blank lines break many JSONL parsers).
    with open("outputs/tango_checkpoint_summary.jsonl", "a") as f:
        f.write(json.dumps(result) + "\n")


if __name__ == "__main__":
    main()